Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-04-08 03:00:28 -04:00

Compare commits: feat/launc...feat/migra (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | e9d846eebb |  |
@@ -30,12 +30,6 @@ poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name

# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs

# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs

# Lint and format
# Prefer `format` if you just want to "fix" things and only see the errors that can't be autofixed
poetry run format  # Black + isort
@@ -128,9 +122,6 @@ Key models (defined in `/backend/schema.prisma`):
5. Register in block registry
6. Generate the block uuid using `uuid.uuid4()`

Note: when making many new blocks, analyze the interfaces of each block and consider whether they would fit well together in a graph-based editor, or whether they would struggle to connect productively.
e.g. do the inputs and outputs tie together well? (See the sketch below.)
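A minimal sketch of that property (hypothetical block schemas, not taken from this repository): the output field of one block should line up with the input field of the next, so the two can be wired directly in the graph editor.

    from pydantic import BaseModel

    class TimeOutput(BaseModel):
        time: str  # e.g. "14:03:25"

    class FormatMessageInput(BaseModel):
        time: str  # accepts TimeOutput.time directly, no glue block needed
        template: str = "The current time is {time}"

    # In a graph editor, TimeOutput.time -> FormatMessageInput.time connects cleanly,
    # which is the "tie well together" property described in the note above.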
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
2. Add/update Pydantic models in the same directory
@@ -7,5 +7,9 @@ class Settings:
        self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
        self.JWT_ALGORITHM: str = "HS256"

    @property
    def is_configured(self) -> bool:
        return bool(self.JWT_SECRET_KEY)


settings = Settings()
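A hedged usage sketch of the new is_configured property (hypothetical call site, not part of this diff):

    # Fail fast at startup when auth is enabled but the JWT secret is missing.
    if settings.ENABLE_AUTH and not settings.is_configured:
        raise RuntimeError("ENABLE_AUTH is true but JWT_SECRET_KEY is not set")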
@@ -1,14 +1,10 @@
import asyncio
import contextlib
import logging
from functools import wraps
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, TypeVar

if TYPE_CHECKING:
    from backend.data.model import User
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast

import ldclient
from backend.util.json import loads as json_loads
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
@@ -60,202 +56,32 @@ def shutdown_launchdarkly() -> None:


def create_context(
    user_id: str, additional_attributes: Optional[dict[str, Any]] = None
    user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
) -> Context:
    """Create LaunchDarkly context with optional additional attributes."""
    # Use the key from attributes if provided, otherwise use user_id
    context_key = user_id
    if additional_attributes and "key" in additional_attributes:
        context_key = additional_attributes["key"]

    builder = Context.builder(str(context_key)).kind("user")

    builder = Context.builder(str(user_id)).kind("user")
    if additional_attributes:
        for key, value in additional_attributes.items():
            # Skip kind and key as they're already set
            if key in ["kind", "key"]:
                continue
            elif key == "custom" and isinstance(value, dict):
                # Handle custom attributes object - these go as individual attributes
                for custom_key, custom_value in value.items():
                    builder.set(custom_key, custom_value)
            else:
                builder.set(key, value)
            builder.set(key, value)
    return builder.build()

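A hedged usage sketch (hypothetical values) showing how a nested "custom" dict is flattened onto the builder by create_context:

    ctx = create_context(
        "user-123",
        {"email": "user@example.com", "custom": {"role": "admin"}},
    )
    # The resulting Context has kind="user", key="user-123", and "email" and "role"
    # set as individual attributes.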
async def _fetch_user_context_data(user_id: str) -> dict[str, Any]:
def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bool:
    """
    Fetch user data and build LaunchDarkly context.

    Args:
        user_id: The user ID to fetch data for

    Returns:
        Dictionary with user context data including role
    """
    # Use the unified database access approach
    from backend.util.clients import get_database_manager_async_client

    db_client = get_database_manager_async_client()
    user = await db_client.get_user_by_id(user_id)

    # Build LaunchDarkly context from user data
    return _build_launchdarkly_context(user)


def _build_launchdarkly_context(user: "User") -> dict[str, Any]:
    """
    Build LaunchDarkly context data matching frontend format.

    Returns a context like:
        {
            "kind": "user",
            "key": "user-id",
            "email": "user@example.com",  # Optional
            "anonymous": false,
            "custom": {
                "role": "admin"  # Optional
            }
        }

    Args:
        user: User object from database

    Returns:
        Dictionary with user context data
    """
    from autogpt_libs.auth.models import DEFAULT_USER_ID

    # Build basic context - always include kind, key, and anonymous
    context_data: dict[str, Any] = {
        "kind": "user",
        "key": user.id,
        "anonymous": False,
    }

    # Add email if present
    if user.email:
        context_data["email"] = user.email

    # Initialize custom attributes
    custom: dict[str, Any] = {}

    # Determine user role from metadata
    role = None

    # Check if user is default/system user
    if user.id == DEFAULT_USER_ID:
        role = "admin"  # Default user has admin privileges when auth is disabled
    elif user.metadata:
        # Check for role in metadata
        try:
            # Handle both string (direct DB) and dict (RPC) formats
            if isinstance(user.metadata, str):
                metadata = json_loads(user.metadata)
            elif isinstance(user.metadata, dict):
                metadata = user.metadata
            else:
                metadata = {}

            # Extract role from metadata if present
            if metadata.get("role"):
                role = metadata["role"]

        except (JSONDecodeError, TypeError) as e:
            logger.debug(f"Failed to parse user metadata for context: {e}")

    # Add role to custom attributes if present
    if role:
        custom["role"] = role

    # Only add custom object if it has content
    if custom:
        context_data["custom"] = custom

    return context_data


async def is_feature_enabled(
    flag_key: str,
    user_id: str,
    default: bool = False,
    use_user_id_only: bool = False,
    additional_attributes: Optional[dict[str, Any]] = None,
    user_role: Optional[str] = None,
) -> bool:
    """
    Check if a feature flag is enabled for a user with full LaunchDarkly context support.
    Simple helper to check if a feature flag is enabled for a user.

    Args:
        flag_key: The LaunchDarkly feature flag key
        user_id: The user ID to evaluate the flag for
        default: Default value if LaunchDarkly is unavailable or flag evaluation fails
        use_user_id_only: If True, only use user_id without fetching database context
        additional_attributes: Additional attributes to include in the context
        user_role: Optional user role (e.g., "admin", "user") to add to segments

    Returns:
        True if feature is enabled, False otherwise
    """
    try:
        client = get_client()

        if use_user_id_only:
            # Simple context with just user ID (for backward compatibility)
            attrs = additional_attributes or {}
            if user_role:
                # Add role to custom attributes for consistency
                if "custom" not in attrs:
                    attrs["custom"] = {}
                if isinstance(attrs["custom"], dict):
                    attrs["custom"]["role"] = user_role
            context = create_context(str(user_id), attrs)
        else:
            # Full context with user segments and metadata from database
            try:
                user_data = await _fetch_user_context_data(user_id)
            except ImportError as e:
                # Database modules not available - fallback to simple context
                logger.debug(f"Database modules not available: {e}")
                user_data = {}
            except Exception as e:
                # Database error - log and fallback to simple context
                logger.warning(f"Failed to fetch user context for {user_id}: {e}")
                user_data = {}

            # Merge additional attributes and role
            attrs = additional_attributes or {}

            # If user_role is provided, add it to custom attributes
            if user_role:
                if "custom" not in user_data:
                    user_data["custom"] = {}
                user_data["custom"]["role"] = user_role

            # Merge additional attributes with user data
            # Handle custom attributes specially
            if "custom" in attrs and isinstance(attrs["custom"], dict):
                if "custom" not in user_data:
                    user_data["custom"] = {}
                user_data["custom"].update(attrs["custom"])
                # Remove custom from attrs to avoid duplication
                attrs = {k: v for k, v in attrs.items() if k != "custom"}

            # Merge remaining attributes
            final_attrs = {**user_data, **attrs}

            context = create_context(str(user_id), final_attrs)

        # Evaluate the flag
        result = client.variation(flag_key, context, default)

        logger.debug(
            f"Feature flag {flag_key} for user {user_id}: {result} "
            f"(use_user_id_only: {use_user_id_only})"
        )

        return result
        context = create_context(str(user_id))
        return client.variation(flag_key, context, default)

    except Exception as e:
        logger.debug(
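A hedged usage sketch of the async helper above (flag key and caller are hypothetical):

    # Inside an async request handler that already knows the user's ID:
    if await is_feature_enabled("agent-builder-v2", user_id, default=False):
        ...  # serve the new code path
    else:
        ...  # fall back to the existing behavior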
@@ -267,19 +93,16 @@ async def is_feature_enabled(
def feature_flag(
    flag_key: str,
    default: bool = False,
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
) -> Callable[
    [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
    """
    Decorator for async feature flag protected endpoints.

    Args:
        flag_key: The LaunchDarkly feature flag key
        default: Default value if flag evaluation fails

    Returns:
        Decorator that only works with async functions
    Decorator for feature flag protected endpoints.
    """

    def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
    def decorator(
        func: Callable[P, Union[T, Awaitable[T]]],
    ) -> Callable[P, Union[T, Awaitable[T]]]:
        @wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            try:
@@ -293,24 +116,73 @@ def feature_flag(
                    )
                    is_enabled = default
                else:
                    # Use the unified function with full context support
                    is_enabled = await is_feature_enabled(
                        flag_key, str(user_id), default, use_user_id_only=False
                    )
                    context = create_context(str(user_id))
                    is_enabled = get_client().variation(flag_key, context, default)

                if not is_enabled:
                    raise HTTPException(status_code=404, detail="Feature not available")

                return await func(*args, **kwargs)
                result = func(*args, **kwargs)
                if asyncio.iscoroutine(result):
                    return await result
                return cast(T, result)
            except Exception as e:
                logger.error(f"Error evaluating feature flag {flag_key}: {e}")
                raise

        return async_wrapper
        @wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            try:
                user_id = kwargs.get("user_id")
                if not user_id:
                    raise ValueError("user_id is required")

                if not get_client().is_initialized():
                    logger.warning(
                        f"LaunchDarkly not initialized, using default={default}"
                    )
                    is_enabled = default
                else:
                    context = create_context(str(user_id))
                    is_enabled = get_client().variation(flag_key, context, default)

                if not is_enabled:
                    raise HTTPException(status_code=404, detail="Feature not available")

                return cast(T, func(*args, **kwargs))
            except Exception as e:
                logger.error(f"Error evaluating feature flag {flag_key}: {e}")
                raise

        return cast(
            Callable[P, Union[T, Awaitable[T]]],
            async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
        )

    return decorator


def percentage_rollout(
    flag_key: str,
    default: bool = False,
) -> Callable[
    [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
    """Decorator for percentage-based rollouts."""
    return feature_flag(flag_key, default)


def beta_feature(
    flag_key: Optional[str] = None,
    unauthorized_response: Any = {"message": "Not available in beta"},
) -> Callable[
    [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
    """Decorator for beta features."""
    actual_key = f"beta-{flag_key}" if flag_key else "beta"
    return feature_flag(actual_key, False)

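A hedged usage sketch of the decorator on a FastAPI route (route path and flag key are hypothetical). The wrapper looks up user_id in the endpoint's keyword arguments and raises a 404 when the flag is off:

    @router.get("/beta/reports")  # hypothetical router
    @feature_flag("beta-reports", default=False)
    async def get_beta_reports(user_id: str):
        return {"reports": []}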
@contextlib.contextmanager
def mock_flag_variation(flag_key: str, return_value: Any):
    """Context manager for testing feature flags."""
@@ -1,26 +1,10 @@
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
    Any,
    Awaitable,
    Callable,
    ParamSpec,
    Protocol,
    Tuple,
    TypeVar,
    cast,
    overload,
    runtime_checkable,
)
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload

P = ParamSpec("P")
R = TypeVar("R")

logger = logging.getLogger(__name__)


@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@@ -73,153 +57,3 @@ def thread_cached(
def clear_thread_cache(func: Callable) -> None:
    if clear := getattr(func, "clear_cache", None):
        clear()


@runtime_checkable
class AsyncCachedFunction(Protocol):
    """Protocol for async functions with cache management methods."""

    def cache_clear(self) -> None:
        """Clear all cached entries."""
        return None

    def cache_info(self) -> dict[str, Any]:
        """Get cache statistics."""
        return {}

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the cached function."""
        return None


def async_ttl_cache(
    maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
    """
    TTL (Time To Live) cache decorator for async functions.

    Similar to functools.lru_cache but works with async functions and includes optional TTL.

    Args:
        maxsize: Maximum number of cached entries
        ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)

    Returns:
        Decorator function

    Example:
        # With TTL
        @async_ttl_cache(maxsize=1000, ttl_seconds=300)
        async def api_call(param: str) -> dict:
            return {"result": param}

        # Without TTL (permanent cache like lru_cache)
        @async_ttl_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            return {"result": param}
    """

    def decorator(
        async_func: Callable[..., Awaitable[Any]],
    ) -> AsyncCachedFunction:
        # Cache storage - use union type to handle both cases
        cache_storage: dict[Any, Any | Tuple[Any, float]] = {}

        @wraps(async_func)
        async def wrapper(*args, **kwargs):
            # Create cache key from arguments
            key = (args, tuple(sorted(kwargs.items())))
            current_time = time.time()

            # Check if we have a valid cached entry
            if key in cache_storage:
                if ttl_seconds is None:
                    # No TTL - return cached result directly
                    logger.debug(
                        f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
                    )
                    return cache_storage[key]
                else:
                    # With TTL - check expiration
                    cached_data = cache_storage[key]
                    if isinstance(cached_data, tuple):
                        result, timestamp = cached_data
                        if current_time - timestamp < ttl_seconds:
                            logger.debug(
                                f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
                            )
                            return result
                        else:
                            # Expired entry
                            del cache_storage[key]
                            logger.debug(
                                f"Cache entry expired for {async_func.__name__}"
                            )

            # Cache miss or expired - fetch fresh data
            logger.debug(
                f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
            )
            result = await async_func(*args, **kwargs)

            # Store in cache
            if ttl_seconds is None:
                cache_storage[key] = result
            else:
                cache_storage[key] = (result, current_time)

            # Simple cleanup when cache gets too large
            if len(cache_storage) > maxsize:
                # Remove oldest entries (simple FIFO cleanup)
                cutoff = maxsize // 2
                oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
                for old_key in oldest_keys:
                    cache_storage.pop(old_key, None)
                logger.debug(
                    f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
                )

            return result

        # Add cache management methods (similar to functools.lru_cache)
        def cache_clear() -> None:
            cache_storage.clear()

        def cache_info() -> dict[str, Any]:
            return {
                "size": len(cache_storage),
                "maxsize": maxsize,
                "ttl_seconds": ttl_seconds,
            }

        # Attach methods to wrapper
        setattr(wrapper, "cache_clear", cache_clear)
        setattr(wrapper, "cache_info", cache_info)

        return cast(AsyncCachedFunction, wrapper)

    return decorator


def async_cache(
    maxsize: int = 128,
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
    """
    Process-level cache decorator for async functions (no TTL).

    Similar to functools.lru_cache but works with async functions.
    This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.

    Args:
        maxsize: Maximum number of cached entries

    Returns:
        Decorator function

    Example:
        @async_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            # Expensive computation here
            return {"result": param}
    """
    return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
@@ -16,12 +16,7 @@ from unittest.mock import Mock

import pytest

from autogpt_libs.utils.cache import (
    async_cache,
    async_ttl_cache,
    clear_thread_cache,
    thread_cached,
)
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached


class TestThreadCached:
@@ -328,378 +323,3 @@ class TestThreadCached:

        assert function_using_mock(2) == 42
        assert mock.call_count == 2


class TestAsyncTTLCache:
    """Tests for the @async_ttl_cache decorator."""

    @pytest.mark.asyncio
    async def test_basic_caching(self):
        """Test basic caching functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Different args - should call function again
        result3 = await cached_function(2, 3)
        assert result3 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_ttl_expiration(self):
        """Test that cache entries expire after TTL."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def short_lived_cache(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        # First call
        result1 = await short_lived_cache(5)
        assert result1 == 10
        assert call_count == 1

        # Second call immediately - should use cache
        result2 = await short_lived_cache(5)
        assert result2 == 10
        assert call_count == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # Third call after expiration - should call function again
        result3 = await short_lived_cache(5)
        assert result3 == 10
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_cache_info(self):
        """Test cache info functionality."""

        @async_ttl_cache(maxsize=5, ttl_seconds=300)
        async def info_test_function(x: int) -> int:
            return x * 3

        # Check initial cache info
        info = info_test_function.cache_info()
        assert info["size"] == 0
        assert info["maxsize"] == 5
        assert info["ttl_seconds"] == 300

        # Add an entry
        await info_test_function(1)
        info = info_test_function.cache_info()
        assert info["size"] == 1

    @pytest.mark.asyncio
    async def test_cache_clear(self):
        """Test cache clearing functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def clearable_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 4

        # First call
        result1 = await clearable_function(2)
        assert result1 == 8
        assert call_count == 1

        # Second call - should use cache
        result2 = await clearable_function(2)
        assert result2 == 8
        assert call_count == 1

        # Clear cache
        clearable_function.cache_clear()

        # Third call after clear - should call function again
        result3 = await clearable_function(2)
        assert result3 == 8
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_maxsize_cleanup(self):
        """Test that cache cleans up when maxsize is exceeded."""
        call_count = 0

        @async_ttl_cache(maxsize=3, ttl_seconds=60)
        async def size_limited_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x**2

        # Fill cache to maxsize
        await size_limited_function(1)  # call_count: 1
        await size_limited_function(2)  # call_count: 2
        await size_limited_function(3)  # call_count: 3

        info = size_limited_function.cache_info()
        assert info["size"] == 3

        # Add one more entry - should trigger cleanup
        await size_limited_function(4)  # call_count: 4

        # Cache size should be reduced (cleanup removes oldest entries)
        info = size_limited_function.cache_info()
        assert info["size"] <= 3  # Should be cleaned up

    @pytest.mark.asyncio
    async def test_argument_variations(self):
        """Test caching with different argument patterns."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
            nonlocal call_count
            call_count += 1
            return f"{a}-{b}-{c}"

        # Different ways to call with same logical arguments
        result1 = await arg_test_function(1, "test", c=200)
        assert call_count == 1

        # Same arguments, same order - should use cache
        result2 = await arg_test_function(1, "test", c=200)
        assert call_count == 1
        assert result1 == result2

        # Different arguments - should call function
        result3 = await arg_test_function(2, "test", c=200)
        assert call_count == 2
        assert result1 != result3

    @pytest.mark.asyncio
    async def test_exception_handling(self):
        """Test that exceptions are not cached."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def exception_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value not allowed")
            return x * 2

        # Successful call - should be cached
        result1 = await exception_function(5)
        assert result1 == 10
        assert call_count == 1

        # Same successful call - should use cache
        result2 = await exception_function(5)
        assert result2 == 10
        assert call_count == 1

        # Exception call - should not be cached
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 2

        # Same exception call - should call again (not cached)
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_concurrent_calls(self):
        """Test caching behavior with concurrent calls."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def concurrent_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.05)  # Simulate work
            return x * x

        # Launch concurrent calls with same arguments
        tasks = [concurrent_function(3) for _ in range(5)]
        results = await asyncio.gather(*tasks)

        # All results should be the same
        assert all(result == 9 for result in results)

        # Note: Due to race conditions, call_count might be up to 5 for concurrent calls
        # This tests that the cache doesn't break under concurrent access
        assert 1 <= call_count <= 5


class TestAsyncCache:
    """Tests for the @async_cache decorator (no TTL)."""

    @pytest.mark.asyncio
    async def test_basic_caching_no_ttl(self):
        """Test basic caching functionality without TTL."""
        call_count = 0

        @async_cache(maxsize=10)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Third call after some time - should still use cache (no TTL)
        await asyncio.sleep(0.05)
        result3 = await cached_function(1, 2)
        assert result3 == 3
        assert call_count == 1  # Still no additional call

        # Different args - should call function again
        result4 = await cached_function(2, 3)
        assert result4 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_no_ttl_vs_ttl_behavior(self):
        """Test the difference between TTL and no-TTL caching."""
        ttl_call_count = 0
        no_ttl_call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def ttl_function(x: int) -> int:
            nonlocal ttl_call_count
            ttl_call_count += 1
            return x * 2

        @async_cache(maxsize=10)  # No TTL
        async def no_ttl_function(x: int) -> int:
            nonlocal no_ttl_call_count
            no_ttl_call_count += 1
            return x * 2

        # First calls
        await ttl_function(5)
        await no_ttl_function(5)
        assert ttl_call_count == 1
        assert no_ttl_call_count == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # Second calls after TTL expiry
        await ttl_function(5)  # Should call function again (TTL expired)
        await no_ttl_function(5)  # Should use cache (no TTL)
        assert ttl_call_count == 2  # TTL function called again
        assert no_ttl_call_count == 1  # No-TTL function still cached

    @pytest.mark.asyncio
    async def test_async_cache_info(self):
        """Test cache info for no-TTL cache."""

        @async_cache(maxsize=5)
        async def info_test_function(x: int) -> int:
            return x * 3

        # Check initial cache info
        info = info_test_function.cache_info()
        assert info["size"] == 0
        assert info["maxsize"] == 5
        assert info["ttl_seconds"] is None  # No TTL

        # Add an entry
        await info_test_function(1)
        info = info_test_function.cache_info()
        assert info["size"] == 1


class TestTTLOptional:
    """Tests for optional TTL functionality."""

    @pytest.mark.asyncio
    async def test_ttl_none_behavior(self):
        """Test that ttl_seconds=None works like no TTL."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=None)
        async def no_ttl_via_none(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x**2

        # First call
        result1 = await no_ttl_via_none(3)
        assert result1 == 9
        assert call_count == 1

        # Wait (would expire if there was TTL)
        await asyncio.sleep(0.1)

        # Second call - should still use cache
        result2 = await no_ttl_via_none(3)
        assert result2 == 9
        assert call_count == 1  # No additional call

        # Check cache info
        info = no_ttl_via_none.cache_info()
        assert info["ttl_seconds"] is None

    @pytest.mark.asyncio
    async def test_cache_options_comparison(self):
        """Test different cache options work as expected."""
        ttl_calls = 0
        no_ttl_calls = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # With TTL
        async def ttl_function(x: int) -> int:
            nonlocal ttl_calls
            ttl_calls += 1
            return x * 10

        @async_cache(maxsize=10)  # Process-level cache (no TTL)
        async def process_function(x: int) -> int:
            nonlocal no_ttl_calls
            no_ttl_calls += 1
            return x * 10

        # Both should cache initially
        await ttl_function(3)
        await process_function(3)
        assert ttl_calls == 1
        assert no_ttl_calls == 1

        # Immediate second calls - both should use cache
        await ttl_function(3)
        await process_function(3)
        assert ttl_calls == 1
        assert no_ttl_calls == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # After TTL expiry
        await ttl_function(3)  # Should call function again
        await process_function(3)  # Should still use cache
        assert ttl_calls == 2  # TTL cache expired, called again
        assert no_ttl_calls == 1  # Process cache never expires
@@ -159,7 +159,6 @@ class AirtableOAuthHandler(BaseOAuthHandler):
        logger.info("Successfully refreshed tokens")

        new_credentials = OAuth2Credentials(
            id=credentials.id,
            access_token=SecretStr(response.access_token),
            refresh_token=SecretStr(response.refresh_token),
            access_token_expires_at=int(time.time()) + response.expires_in,
@@ -1,5 +1,3 @@
from enum import Enum

from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
@@ -13,12 +11,6 @@ from backend.sdk import (
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class TikTokVisibility(str, Enum):
    PUBLIC = "public"
    PRIVATE = "private"
    FOLLOWERS = "followers"


class PostToTikTokBlock(Block):
    """Block for posting to TikTok with TikTok-specific options."""

@@ -28,6 +20,7 @@ class PostToTikTokBlock(Block):
        # Override post field to include TikTok-specific information
        post: str = SchemaField(
            description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
            default="",
            advanced=False,
        )

@@ -40,7 +33,7 @@ class PostToTikTokBlock(Block):

        # TikTok-specific options
        auto_add_music: bool = SchemaField(
            description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
            description="Automatically add recommended music to image posts",
            default=False,
            advanced=True,
        )
@@ -60,17 +53,17 @@ class PostToTikTokBlock(Block):
            advanced=True,
        )
        is_ai_generated: bool = SchemaField(
            description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and can’t be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
            description="Label content as AI-generated (video only)",
            default=False,
            advanced=True,
        )
        is_branded_content: bool = SchemaField(
            description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
            description="Label as branded content (paid partnership)",
            default=False,
            advanced=True,
        )
        is_brand_organic: bool = SchemaField(
            description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
            description="Label as brand organic content (promotional)",
            default=False,
            advanced=True,
        )
@@ -87,9 +80,9 @@ class PostToTikTokBlock(Block):
            default=0,
            advanced=True,
        )
        visibility: TikTokVisibility = SchemaField(
        visibility: str = SchemaField(
            description="Post visibility: 'public', 'private', 'followers', or 'friends'",
            default=TikTokVisibility.PUBLIC,
            default="public",
            advanced=True,
        )
        draft: bool = SchemaField(
@@ -104,6 +97,7 @@ class PostToTikTokBlock(Block):

    def __init__(self):
        super().__init__(
            disabled=True,
            id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
            description="Post to TikTok using Ayrshare",
            categories={BlockCategory.SOCIAL},
@@ -166,6 +160,12 @@ class PostToTikTokBlock(Block):
            yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
            return

        # Validate visibility option
        valid_visibility = ["public", "private", "followers", "friends"]
        if input_data.visibility not in valid_visibility:
            yield "error", f"TikTok visibility must be one of: {', '.join(valid_visibility)}"
            return

        # Check for PNG files (not supported)
        has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
        if has_png:
@@ -218,8 +218,8 @@ class PostToTikTokBlock(Block):
        if input_data.title:
            tiktok_options["title"] = input_data.title

        if input_data.visibility != TikTokVisibility.PUBLIC:
            tiktok_options["visibility"] = input_data.visibility.value
        if input_data.visibility != "public":
            tiktok_options["visibility"] = input_data.visibility

        response = await client.create_post(
            post=input_data.post,
File diff suppressed because it is too large
@@ -1,247 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


# Enum definitions based on available options
class WebsetStatus(str, Enum):
    IDLE = "idle"
    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"


class WebsetSearchStatus(str, Enum):
    CREATED = "created"
    # Add more if known, based on example it's "created"


class ImportStatus(str, Enum):
    PENDING = "pending"
    # Add more if known


class ImportFormat(str, Enum):
    CSV = "csv"
    # Add more if known


class EnrichmentStatus(str, Enum):
    PENDING = "pending"
    # Add more if known


class EnrichmentFormat(str, Enum):
    TEXT = "text"
    # Add more if known


class MonitorStatus(str, Enum):
    ENABLED = "enabled"
    # Add more if known


class MonitorBehaviorType(str, Enum):
    SEARCH = "search"
    # Add more if known


class MonitorRunStatus(str, Enum):
    CREATED = "created"
    # Add more if known


class CanceledReason(str, Enum):
    WEBSET_DELETED = "webset_deleted"
    # Add more if known


class FailedReason(str, Enum):
    INVALID_FORMAT = "invalid_format"
    # Add more if known


class Confidence(str, Enum):
    HIGH = "high"
    # Add more if known


# Nested models


class Entity(BaseModel):
    type: str


class Criterion(BaseModel):
    description: str
    successRate: Optional[int] = None


class ExcludeItem(BaseModel):
    source: str = Field(default="import")
    id: str


class Relationship(BaseModel):
    definition: str
    limit: Optional[float] = None


class ScopeItem(BaseModel):
    source: str = Field(default="import")
    id: str
    relationship: Optional[Relationship] = None


class Progress(BaseModel):
    found: int
    analyzed: int
    completion: int
    timeLeft: int


class Bounds(BaseModel):
    min: int
    max: int


class Expected(BaseModel):
    total: int
    confidence: str = Field(default="high")  # Use str or Confidence enum
    bounds: Bounds


class Recall(BaseModel):
    expected: Expected
    reasoning: str


class WebsetSearch(BaseModel):
    id: str
    object: str = Field(default="webset_search")
    status: str = Field(default="created")  # Or use WebsetSearchStatus
    websetId: str
    query: str
    entity: Entity
    criteria: List[Criterion]
    count: int
    behavior: str = Field(default="override")
    exclude: List[ExcludeItem]
    scope: List[ScopeItem]
    progress: Progress
    recall: Recall
    metadata: Dict[str, Any] = Field(default_factory=dict)
    canceledAt: Optional[datetime] = None
    canceledReason: Optional[str] = Field(default=None)  # Or use CanceledReason
    createdAt: datetime
    updatedAt: datetime


class ImportEntity(BaseModel):
    type: str


class Import(BaseModel):
    id: str
    object: str = Field(default="import")
    status: str = Field(default="pending")  # Or use ImportStatus
    format: str = Field(default="csv")  # Or use ImportFormat
    entity: ImportEntity
    title: str
    count: int
    metadata: Dict[str, Any] = Field(default_factory=dict)
    failedReason: Optional[str] = Field(default=None)  # Or use FailedReason
    failedAt: Optional[datetime] = None
    failedMessage: Optional[str] = None
    createdAt: datetime
    updatedAt: datetime


class Option(BaseModel):
    label: str


class WebsetEnrichment(BaseModel):
    id: str
    object: str = Field(default="webset_enrichment")
    status: str = Field(default="pending")  # Or use EnrichmentStatus
    websetId: str
    title: str
    description: str
    format: str = Field(default="text")  # Or use EnrichmentFormat
    options: List[Option]
    instructions: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    createdAt: datetime
    updatedAt: datetime


class Cadence(BaseModel):
    cron: str
    timezone: str = Field(default="Etc/UTC")


class BehaviorConfig(BaseModel):
    query: Optional[str] = None
    criteria: Optional[List[Criterion]] = None
    entity: Optional[Entity] = None
    count: Optional[int] = None
    behavior: Optional[str] = Field(default=None)


class Behavior(BaseModel):
    type: str = Field(default="search")  # Or use MonitorBehaviorType
    config: BehaviorConfig


class MonitorRun(BaseModel):
    id: str
    object: str = Field(default="monitor_run")
    status: str = Field(default="created")  # Or use MonitorRunStatus
    monitorId: str
    type: str = Field(default="search")
    completedAt: Optional[datetime] = None
    failedAt: Optional[datetime] = None
    failedReason: Optional[str] = None
    canceledAt: Optional[datetime] = None
    createdAt: datetime
    updatedAt: datetime


class Monitor(BaseModel):
    id: str
    object: str = Field(default="monitor")
    status: str = Field(default="enabled")  # Or use MonitorStatus
    websetId: str
    cadence: Cadence
    behavior: Behavior
    lastRun: Optional[MonitorRun] = None
    nextRunAt: Optional[datetime] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    createdAt: datetime
    updatedAt: datetime


class Webset(BaseModel):
    id: str
    object: str = Field(default="webset")
    status: WebsetStatus
    externalId: Optional[str] = None
    title: Optional[str] = None
    searches: List[WebsetSearch]
    imports: List[Import]
    enrichments: List[WebsetEnrichment]
    monitors: List[Monitor]
    streams: List[Any]
    createdAt: datetime
    updatedAt: datetime
    metadata: Dict[str, Any] = Field(default_factory=dict)


class ListWebsets(BaseModel):
    data: List[Webset]
    hasMore: bool
    nextCursor: Optional[str] = None
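For reference, a hedged sketch of how these (now removed) models would be instantiated, assuming Pydantic v2 and purely hypothetical values:

    from datetime import datetime, timezone

    webset = Webset(
        id="ws_123",
        status=WebsetStatus.IDLE,
        searches=[],
        imports=[],
        enrichments=[],
        monitors=[],
        streams=[],
        createdAt=datetime.now(timezone.utc),
        updatedAt=datetime.now(timezone.utc),
    )
    # Defaults such as object="webset" and metadata={} are filled in automatically.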
@@ -114,7 +114,6 @@ class ExaWebsetWebhookBlock(Block):

    def __init__(self):
        super().__init__(
            disabled=True,
            id="d0204ed8-8b81-408d-8b8d-ed087a546228",
            description="Receive webhook notifications for Exa webset events",
            categories={BlockCategory.INPUT},
@@ -1,33 +1,7 @@
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional

from exa_py import Exa
from exa_py.websets.types import (
    CreateCriterionParameters,
    CreateEnrichmentParameters,
    CreateWebsetParameters,
    CreateWebsetParametersSearch,
    ExcludeItem,
    Format,
    ImportItem,
    ImportSource,
    Option,
    ScopeItem,
    ScopeRelationship,
    ScopeSourceType,
    WebsetArticleEntity,
    WebsetCompanyEntity,
    WebsetCustomEntity,
    WebsetPersonEntity,
    WebsetResearchPaperEntity,
    WebsetStatus,
)
from pydantic import Field
from typing import Any, Optional

from backend.sdk import (
    APIKeyCredentials,
    BaseModel,
    Block,
    BlockCategory,
    BlockOutput,
@@ -38,69 +12,7 @@ from backend.sdk import (
)

from ._config import exa


class SearchEntityType(str, Enum):
    COMPANY = "company"
    PERSON = "person"
    ARTICLE = "article"
    RESEARCH_PAPER = "research_paper"
    CUSTOM = "custom"
    AUTO = "auto"


class SearchType(str, Enum):
    IMPORT = "import"
    WEBSET = "webset"


class EnrichmentFormat(str, Enum):
    TEXT = "text"
    DATE = "date"
    NUMBER = "number"
    OPTIONS = "options"
    EMAIL = "email"
    PHONE = "phone"


class Webset(BaseModel):
    id: str
    status: WebsetStatus | None = Field(..., title="WebsetStatus")
    """
    The status of the webset
    """
    external_id: Annotated[Optional[str], Field(alias="externalId")] = None
    """
    The external identifier for the webset
    NOTE: Returning dict to avoid ui crashing due to nested objects
    """
    searches: List[dict[str, Any]] | None = None
    """
    The searches that have been performed on the webset.
    NOTE: Returning dict to avoid ui crashing due to nested objects
    """
    enrichments: List[dict[str, Any]] | None = None
    """
    The Enrichments to apply to the Webset Items.
    NOTE: Returning dict to avoid ui crashing due to nested objects
    """
    monitors: List[dict[str, Any]] | None = None
    """
    The Monitors for the Webset.
    NOTE: Returning dict to avoid ui crashing due to nested objects
    """
    metadata: Optional[Dict[str, Any]] = {}
    """
    Set of key-value pairs you want to associate with this object.
    """
    created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
    """
    The date and time the webset was created
    """
    updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
    """
    The date and time the webset was last updated
    """
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig


class ExaCreateWebsetBlock(Block):
@@ -108,121 +20,40 @@
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )

        # Search parameters (flattened)
        search_query: str = SchemaField(
            description="Your search query. Use this to describe what you are looking for. Any URL provided will be crawled and used as context for the search.",
            placeholder="Marketing agencies based in the US, that focus on consumer products",
        search: WebsetSearchConfig = SchemaField(
            description="Initial search configuration for the Webset"
        )
        search_count: Optional[int] = SchemaField(
            default=10,
            description="Number of items the search will attempt to find. The actual number of items found may be less than this number depending on the search complexity.",
            ge=1,
            le=1000,
        )
        search_entity_type: SearchEntityType = SchemaField(
            default=SearchEntityType.AUTO,
            description="Entity type: 'company', 'person', 'article', 'research_paper', or 'custom'. If not provided, we automatically detect the entity from the query.",
            advanced=True,
        )
        search_entity_description: Optional[str] = SchemaField(
        enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
            default=None,
            description="Description for custom entity type (required when search_entity_type is 'custom')",
            description="Enrichments to apply to Webset items",
            advanced=True,
        )

        # Search criteria (flattened)
        search_criteria: list[str] = SchemaField(
            default_factory=list,
            description="List of criteria descriptions that every item will be evaluated against. If not provided, we automatically detect the criteria from the query.",
            advanced=True,
        )

        # Search exclude sources (flattened)
        search_exclude_sources: list[str] = SchemaField(
            default_factory=list,
            description="List of source IDs (imports or websets) to exclude from search results",
            advanced=True,
        )
        search_exclude_types: list[SearchType] = SchemaField(
            default_factory=list,
            description="List of source types corresponding to exclude sources ('import' or 'webset')",
            advanced=True,
        )

        # Search scope sources (flattened)
        search_scope_sources: list[str] = SchemaField(
            default_factory=list,
            description="List of source IDs (imports or websets) to limit search scope to",
            advanced=True,
        )
        search_scope_types: list[SearchType] = SchemaField(
            default_factory=list,
            description="List of source types corresponding to scope sources ('import' or 'webset')",
            advanced=True,
        )
        search_scope_relationships: list[str] = SchemaField(
            default_factory=list,
            description="List of relationship definitions for hop searches (optional, one per scope source)",
            advanced=True,
        )
        search_scope_relationship_limits: list[int] = SchemaField(
            default_factory=list,
            description="List of limits on the number of related entities to find (optional, one per scope relationship)",
            advanced=True,
        )

        # Import parameters (flattened)
        import_sources: list[str] = SchemaField(
            default_factory=list,
            description="List of source IDs to import from",
            advanced=True,
        )
        import_types: list[SearchType] = SchemaField(
            default_factory=list,
            description="List of source types corresponding to import sources ('import' or 'webset')",
            advanced=True,
        )

        # Enrichment parameters (flattened)
        enrichment_descriptions: list[str] = SchemaField(
            default_factory=list,
            description="List of enrichment task descriptions to perform on each webset item",
            advanced=True,
        )
        enrichment_formats: list[EnrichmentFormat] = SchemaField(
            default_factory=list,
            description="List of formats for enrichment responses ('text', 'date', 'number', 'options', 'email', 'phone'). If not specified, we automatically select the best format.",
            advanced=True,
        )
        enrichment_options: list[list[str]] = SchemaField(
            default_factory=list,
            description="List of option lists for enrichments with 'options' format. Each inner list contains the option labels.",
            advanced=True,
        )
        enrichment_metadata: list[dict] = SchemaField(
            default_factory=list,
            description="List of metadata dictionaries for enrichments",
            advanced=True,
        )

        # Webset metadata
        external_id: Optional[str] = SchemaField(
            default=None,
            description="External identifier for the webset. You can use this to reference the webset by your own internal identifiers.",
            description="External identifier for the webset",
            placeholder="my-webset-123",
            advanced=True,
        )
        metadata: Optional[dict] = SchemaField(
            default_factory=dict,
            default=None,
            description="Key-value pairs to associate with this webset",
            advanced=True,
        )

    class Output(BlockSchema):
        webset: Webset = SchemaField(
        webset_id: str = SchemaField(
            description="The unique identifier for the created webset"
        )
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        created_at: str = SchemaField(
            description="The date and time the webset was created"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
@@ -236,171 +67,44 @@ class ExaCreateWebsetBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/websets/v0/websets"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        exa = Exa(credentials.api_key.get_secret_value())
        # Build the payload
        payload: dict[str, Any] = {
            "search": input_data.search.model_dump(exclude_none=True),
        }

        # ------------------------------------------------------------
        # Build entity (if explicitly provided)
        # ------------------------------------------------------------
        entity = None
        if input_data.search_entity_type == SearchEntityType.COMPANY:
            entity = WebsetCompanyEntity(type="company")
        elif input_data.search_entity_type == SearchEntityType.PERSON:
            entity = WebsetPersonEntity(type="person")
        elif input_data.search_entity_type == SearchEntityType.ARTICLE:
            entity = WebsetArticleEntity(type="article")
        elif input_data.search_entity_type == SearchEntityType.RESEARCH_PAPER:
            entity = WebsetResearchPaperEntity(type="research_paper")
        elif (
            input_data.search_entity_type == SearchEntityType.CUSTOM
            and input_data.search_entity_description
        ):
            entity = WebsetCustomEntity(
                type="custom", description=input_data.search_entity_description
            )
        # Convert enrichments to API format
        if input_data.enrichments:
            enrichments_data = []
            for enrichment in input_data.enrichments:
                enrichments_data.append(enrichment.model_dump(exclude_none=True))
            payload["enrichments"] = enrichments_data

        # ------------------------------------------------------------
        # Build criteria list
        # ------------------------------------------------------------
        criteria = None
        if input_data.search_criteria:
            criteria = [
                CreateCriterionParameters(description=item)
                for item in input_data.search_criteria
            ]
        if input_data.external_id:
            payload["externalId"] = input_data.external_id

        # ------------------------------------------------------------
        # Build exclude sources list
        # ------------------------------------------------------------
        exclude_items = None
        if input_data.search_exclude_sources:
            exclude_items = []
            for idx, src_id in enumerate(input_data.search_exclude_sources):
                src_type = None
                if input_data.search_exclude_types and idx < len(
                    input_data.search_exclude_types
                ):
                    src_type = input_data.search_exclude_types[idx]
                # Default to IMPORT if type missing
                if src_type == SearchType.WEBSET:
                    source_enum = ImportSource.webset
                else:
                    source_enum = ImportSource.import_
                exclude_items.append(ExcludeItem(source=source_enum, id=src_id))
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        # ------------------------------------------------------------
        # Build scope list
        # ------------------------------------------------------------
        scope_items = None
        if input_data.search_scope_sources:
            scope_items = []
            for idx, src_id in enumerate(input_data.search_scope_sources):
                src_type = None
                if input_data.search_scope_types and idx < len(
                    input_data.search_scope_types
                ):
                    src_type = input_data.search_scope_types[idx]
                relationship = None
                if input_data.search_scope_relationships and idx < len(
                    input_data.search_scope_relationships
                ):
                    rel_def = input_data.search_scope_relationships[idx]
                    lim = None
                    if input_data.search_scope_relationship_limits and idx < len(
                        input_data.search_scope_relationship_limits
                    ):
                        lim = input_data.search_scope_relationship_limits[idx]
                    relationship = ScopeRelationship(definition=rel_def, limit=lim)
                if src_type == SearchType.WEBSET:
                    src_enum = ScopeSourceType.webset
                else:
                    src_enum = ScopeSourceType.import_
                scope_items.append(
                    ScopeItem(source=src_enum, id=src_id, relationship=relationship)
                )
        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()

        # ------------------------------------------------------------
        # Assemble search parameters (only if a query is provided)
        # ------------------------------------------------------------
        search_params = None
        if input_data.search_query:
            search_params = CreateWebsetParametersSearch(
                query=input_data.search_query,
                count=input_data.search_count,
                entity=entity,
                criteria=criteria,
                exclude=exclude_items,
                scope=scope_items,
            )
            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "created_at", data.get("createdAt", "")

        # ------------------------------------------------------------
        # Build imports list
        # ------------------------------------------------------------
        imports_params = None
        if input_data.import_sources:
            imports_params = []
            for idx, src_id in enumerate(input_data.import_sources):
                src_type = None
                if input_data.import_types and idx < len(input_data.import_types):
                    src_type = input_data.import_types[idx]
                if src_type == SearchType.WEBSET:
                    source_enum = ImportSource.webset
                else:
                    source_enum = ImportSource.import_
                imports_params.append(ImportItem(source=source_enum, id=src_id))

        # ------------------------------------------------------------
        # Build enrichment list
        # ------------------------------------------------------------
        enrichments_params = None
        if input_data.enrichment_descriptions:
            enrichments_params = []
            for idx, desc in enumerate(input_data.enrichment_descriptions):
                fmt = None
                if input_data.enrichment_formats and idx < len(
                    input_data.enrichment_formats
                ):
                    fmt_enum = input_data.enrichment_formats[idx]
                    if fmt_enum is not None:
                        fmt = Format(
                            fmt_enum.value if isinstance(fmt_enum, Enum) else fmt_enum
                        )
                options_list = None
                if input_data.enrichment_options and idx < len(
                    input_data.enrichment_options
                ):
                    raw_opts = input_data.enrichment_options[idx]
                    if raw_opts:
                        options_list = [Option(label=o) for o in raw_opts]
                metadata_obj = None
                if input_data.enrichment_metadata and idx < len(
                    input_data.enrichment_metadata
                ):
                    metadata_obj = input_data.enrichment_metadata[idx]
                enrichments_params.append(
                    CreateEnrichmentParameters(
                        description=desc,
                        format=fmt,
                        options=options_list,
                        metadata=metadata_obj,
                    )
                )

        # ------------------------------------------------------------
        # Create the webset
        # ------------------------------------------------------------
        webset = exa.websets.create(
            params=CreateWebsetParameters(
                search=search_params,
                imports=imports_params,
                enrichments=enrichments_params,
                external_id=input_data.external_id,
                metadata=input_data.metadata,
|
||||
)
|
||||
)
|
||||
|
||||
# Use alias field names returned from Exa SDK so that nested models validate correctly
|
||||
yield "webset", Webset.model_validate(webset.model_dump(by_alias=True))
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
yield "webset_id", ""
|
||||
yield "status", ""
|
||||
yield "created_at", ""
|
||||
|
||||
|
||||
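# Illustration only (not part of the diff above): the exclude/scope inputs are parallel
# lists paired by index, so a shorter "types" or "relationships" list simply leaves the
# remaining items at their defaults. Field names are taken from the code above; the
# values below are made up.
#
#   example_input = {
#       "search_scope_sources": ["webset_123", "import_abc"],
#       "search_scope_types": [SearchType.WEBSET],      # index 1 has no type -> import_
#       "search_scope_relationships": ["companies mentioned"],
#       "search_scope_relationship_limits": [25],
#   }
#
# Index 0 becomes ScopeItem(source=ScopeSourceType.webset, id="webset_123",
# relationship=ScopeRelationship(definition="companies mentioned", limit=25));
# index 1 becomes ScopeItem(source=ScopeSourceType.import_, id="import_abc",
# relationship=None).
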
class ExaUpdateWebsetBlock(Block):
@@ -479,11 +183,6 @@ class ExaListWebsetsBlock(Block):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        trigger: Any | None = SchemaField(
            default=None,
            description="Trigger for the webset, value is ignored!",
            advanced=False,
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination through results",
@@ -498,9 +197,7 @@ class ExaListWebsetsBlock(Block):
        )

    class Output(BlockSchema):
        websets: list[Webset] = SchemaField(
            description="List of websets", default_factory=list
        )
        websets: list = SchemaField(description="List of websets", default_factory=list)
        has_more: bool = SchemaField(
            description="Whether there are more results to paginate through",
            default=False,
@@ -558,6 +255,9 @@ class ExaGetWebsetBlock(Block):
            description="The ID or external ID of the Webset to retrieve",
            placeholder="webset-id-or-external-id",
        )
        expand_items: bool = SchemaField(
            default=False, description="Include items in the response", advanced=True
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
@@ -609,8 +309,12 @@ class ExaGetWebsetBlock(Block):
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params = {}
        if input_data.expand_items:
            params["expand[]"] = "items"

        try:
            response = await Requests().get(url, headers=headers)
            response = await Requests().get(url, headers=headers, params=params)
            data = response.json()

            yield "webset_id", data.get("id", "")

@@ -1,388 +0,0 @@
import logging
import re
from enum import Enum
from typing import Optional

from typing_extensions import TypedDict

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField

from ._api import get_api
from ._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    GithubCredentials,
    GithubCredentialsField,
    GithubCredentialsInput,
)

logger = logging.getLogger(__name__)


class CheckRunStatus(Enum):
    QUEUED = "queued"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"


class CheckRunConclusion(Enum):
    SUCCESS = "success"
    FAILURE = "failure"
    NEUTRAL = "neutral"
    CANCELLED = "cancelled"
    SKIPPED = "skipped"
    TIMED_OUT = "timed_out"
    ACTION_REQUIRED = "action_required"


class GithubGetCIResultsBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        target: str | int = SchemaField(
            description="Commit SHA or PR number to get CI results for",
            placeholder="abc123def or 123",
        )
        search_pattern: Optional[str] = SchemaField(
            description="Optional regex pattern to search for in CI logs (e.g., error messages, file names)",
            placeholder=".*error.*|.*warning.*",
            default=None,
            advanced=True,
        )
        check_name_filter: Optional[str] = SchemaField(
            description="Optional filter for specific check names (supports wildcards)",
            placeholder="*lint* or build-*",
            default=None,
            advanced=True,
        )

class Output(BlockSchema):
|
||||
class CheckRunItem(TypedDict, total=False):
|
||||
id: int
|
||||
name: str
|
||||
status: str
|
||||
conclusion: Optional[str]
|
||||
started_at: Optional[str]
|
||||
completed_at: Optional[str]
|
||||
html_url: str
|
||||
details_url: Optional[str]
|
||||
output_title: Optional[str]
|
||||
output_summary: Optional[str]
|
||||
output_text: Optional[str]
|
||||
annotations: list[dict]
|
||||
|
||||
class MatchedLine(TypedDict):
|
||||
check_name: str
|
||||
line_number: int
|
||||
line: str
|
||||
context: list[str]
|
||||
|
||||
check_run: CheckRunItem = SchemaField(
|
||||
title="Check Run",
|
||||
description="Individual CI check run with details",
|
||||
)
|
||||
check_runs: list[CheckRunItem] = SchemaField(
|
||||
description="List of all CI check runs"
|
||||
)
|
||||
matched_line: MatchedLine = SchemaField(
|
||||
title="Matched Line",
|
||||
description="Line matching the search pattern with context",
|
||||
)
|
||||
matched_lines: list[MatchedLine] = SchemaField(
|
||||
description="All lines matching the search pattern across all checks"
|
||||
)
|
||||
overall_status: str = SchemaField(
|
||||
description="Overall CI status (pending, success, failure)"
|
||||
)
|
||||
overall_conclusion: str = SchemaField(
|
||||
description="Overall CI conclusion if completed"
|
||||
)
|
||||
total_checks: int = SchemaField(description="Total number of CI checks")
|
||||
passed_checks: int = SchemaField(description="Number of passed checks")
|
||||
failed_checks: int = SchemaField(description="Number of failed checks")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8ad9e103-78f2-4fdb-ba12-3571f2c95e98",
|
||||
description="This block gets CI results for a commit or PR, with optional search for specific errors/warnings in logs.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetCIResultsBlock.Input,
|
||||
output_schema=GithubGetCIResultsBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"target": "abc123def456",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("overall_status", "completed"),
|
||||
("overall_conclusion", "success"),
|
||||
("total_checks", 1),
|
||||
("passed_checks", 1),
|
||||
("failed_checks", 0),
|
||||
(
|
||||
"check_runs",
|
||||
[
|
||||
{
|
||||
"id": 123456,
|
||||
"name": "build",
|
||||
"status": "completed",
|
||||
"conclusion": "success",
|
||||
"started_at": "2024-01-01T00:00:00Z",
|
||||
"completed_at": "2024-01-01T00:05:00Z",
|
||||
"html_url": "https://github.com/owner/repo/runs/123456",
|
||||
"details_url": None,
|
||||
"output_title": "Build passed",
|
||||
"output_summary": "All tests passed",
|
||||
"output_text": "Build log output...",
|
||||
"annotations": [],
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_ci_results": lambda *args, **kwargs: {
|
||||
"check_runs": [
|
||||
{
|
||||
"id": 123456,
|
||||
"name": "build",
|
||||
"status": "completed",
|
||||
"conclusion": "success",
|
||||
"started_at": "2024-01-01T00:00:00Z",
|
||||
"completed_at": "2024-01-01T00:05:00Z",
|
||||
"html_url": "https://github.com/owner/repo/runs/123456",
|
||||
"details_url": None,
|
||||
"output_title": "Build passed",
|
||||
"output_summary": "All tests passed",
|
||||
"output_text": "Build log output...",
|
||||
"annotations": [],
|
||||
}
|
||||
],
|
||||
"total_count": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_commit_sha(api, repo: str, target: str | int) -> str:
|
||||
"""Get commit SHA from either a commit SHA or PR URL."""
|
||||
# If it's already a SHA, return it
|
||||
|
||||
if isinstance(target, str):
|
||||
if re.match(r"^[0-9a-f]{6,40}$", target, re.IGNORECASE):
|
||||
return target
|
||||
|
||||
# If it's a PR URL, get the head SHA
|
||||
if isinstance(target, int):
|
||||
pr_url = f"https://api.github.com/repos/{repo}/pulls/{target}"
|
||||
response = await api.get(pr_url)
|
||||
pr_data = response.json()
|
||||
return pr_data["head"]["sha"]
|
||||
|
||||
raise ValueError("Target must be a commit SHA or PR URL")
|
||||
|
||||
@staticmethod
|
||||
async def search_in_logs(
|
||||
check_runs: list,
|
||||
pattern: str,
|
||||
) -> list[Output.MatchedLine]:
|
||||
"""Search for pattern in check run logs."""
|
||||
if not pattern:
|
||||
return []
|
||||
|
||||
matched_lines = []
|
||||
regex = re.compile(pattern, re.IGNORECASE | re.MULTILINE)
|
||||
|
||||
for check in check_runs:
|
||||
output_text = check.get("output_text", "") or ""
|
||||
if not output_text:
|
||||
continue
|
||||
|
||||
lines = output_text.split("\n")
|
||||
for i, line in enumerate(lines):
|
||||
if regex.search(line):
|
||||
# Get context (2 lines before and after)
|
||||
start = max(0, i - 2)
|
||||
end = min(len(lines), i + 3)
|
||||
context = lines[start:end]
|
||||
|
||||
matched_lines.append(
|
||||
{
|
||||
"check_name": check["name"],
|
||||
"line_number": i + 1,
|
||||
"line": line,
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
return matched_lines
|
||||
|
||||
@staticmethod
|
||||
async def get_ci_results(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
target: str | int,
|
||||
search_pattern: Optional[str] = None,
|
||||
check_name_filter: Optional[str] = None,
|
||||
) -> dict:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# Get the commit SHA
|
||||
commit_sha = await GithubGetCIResultsBlock.get_commit_sha(api, repo, target)
|
||||
|
||||
# Get check runs for the commit
|
||||
check_runs_url = (
|
||||
f"https://api.github.com/repos/{repo}/commits/{commit_sha}/check-runs"
|
||||
)
|
||||
|
||||
# Get all pages of check runs
|
||||
all_check_runs = []
|
||||
page = 1
|
||||
per_page = 100
|
||||
|
||||
while True:
|
||||
response = await api.get(
|
||||
check_runs_url, params={"per_page": per_page, "page": page}
|
||||
)
|
||||
data = response.json()
|
||||
|
||||
check_runs = data.get("check_runs", [])
|
||||
all_check_runs.extend(check_runs)
|
||||
|
||||
if len(check_runs) < per_page:
|
||||
break
|
||||
page += 1
|
||||
|
||||
# Filter by check name if specified
|
||||
if check_name_filter:
|
||||
import fnmatch
|
||||
|
||||
filtered_runs = []
|
||||
for run in all_check_runs:
|
||||
if fnmatch.fnmatch(run["name"].lower(), check_name_filter.lower()):
|
||||
filtered_runs.append(run)
|
||||
all_check_runs = filtered_runs
|
||||
|
||||
# Get check run details with logs
|
||||
detailed_runs = []
|
||||
for run in all_check_runs:
|
||||
# Get detailed output including logs
|
||||
if run.get("output", {}).get("text"):
|
||||
# Already has output
|
||||
detailed_run = {
|
||||
"id": run["id"],
|
||||
"name": run["name"],
|
||||
"status": run["status"],
|
||||
"conclusion": run.get("conclusion"),
|
||||
"started_at": run.get("started_at"),
|
||||
"completed_at": run.get("completed_at"),
|
||||
"html_url": run["html_url"],
|
||||
"details_url": run.get("details_url"),
|
||||
"output_title": run.get("output", {}).get("title"),
|
||||
"output_summary": run.get("output", {}).get("summary"),
|
||||
"output_text": run.get("output", {}).get("text"),
|
||||
"annotations": [],
|
||||
}
|
||||
else:
|
||||
# Try to get logs from the check run
|
||||
detailed_run = {
|
||||
"id": run["id"],
|
||||
"name": run["name"],
|
||||
"status": run["status"],
|
||||
"conclusion": run.get("conclusion"),
|
||||
"started_at": run.get("started_at"),
|
||||
"completed_at": run.get("completed_at"),
|
||||
"html_url": run["html_url"],
|
||||
"details_url": run.get("details_url"),
|
||||
"output_title": run.get("output", {}).get("title"),
|
||||
"output_summary": run.get("output", {}).get("summary"),
|
||||
"output_text": None,
|
||||
"annotations": [],
|
||||
}
|
||||
|
||||
# Get annotations if available
|
||||
if run.get("output", {}).get("annotations_count", 0) > 0:
|
||||
annotations_url = f"https://api.github.com/repos/{repo}/check-runs/{run['id']}/annotations"
|
||||
try:
|
||||
ann_response = await api.get(annotations_url)
|
||||
detailed_run["annotations"] = ann_response.json()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
detailed_runs.append(detailed_run)
|
||||
|
||||
return {
|
||||
"check_runs": detailed_runs,
|
||||
"total_count": len(detailed_runs),
|
||||
}
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
try:
|
||||
target = int(input_data.target)
|
||||
except ValueError:
|
||||
target = input_data.target
|
||||
|
||||
result = await self.get_ci_results(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
target,
|
||||
input_data.search_pattern,
|
||||
input_data.check_name_filter,
|
||||
)
|
||||
|
||||
check_runs = result["check_runs"]
|
||||
|
||||
# Calculate overall status
|
||||
if not check_runs:
|
||||
yield "overall_status", "no_checks"
|
||||
yield "overall_conclusion", "no_checks"
|
||||
else:
|
||||
all_completed = all(run["status"] == "completed" for run in check_runs)
|
||||
if all_completed:
|
||||
yield "overall_status", "completed"
|
||||
# Determine overall conclusion
|
||||
has_failure = any(
|
||||
run["conclusion"] in ["failure", "timed_out", "action_required"]
|
||||
for run in check_runs
|
||||
)
|
||||
if has_failure:
|
||||
yield "overall_conclusion", "failure"
|
||||
else:
|
||||
yield "overall_conclusion", "success"
|
||||
else:
|
||||
yield "overall_status", "pending"
|
||||
yield "overall_conclusion", "pending"
|
||||
|
||||
# Count checks
|
||||
total = len(check_runs)
|
||||
passed = sum(1 for run in check_runs if run.get("conclusion") == "success")
|
||||
failed = sum(
|
||||
1 for run in check_runs if run.get("conclusion") in ["failure", "timed_out"]
|
||||
)
|
||||
|
||||
yield "total_checks", total
|
||||
yield "passed_checks", passed
|
||||
yield "failed_checks", failed
|
||||
|
||||
# Output check runs
|
||||
yield "check_runs", check_runs
|
||||
|
||||
# Search for patterns if specified
|
||||
if input_data.search_pattern:
|
||||
matched_lines = await self.search_in_logs(
|
||||
check_runs, input_data.search_pattern
|
||||
)
|
||||
if matched_lines:
|
||||
yield "matched_lines", matched_lines
|
||||
@@ -1,840 +0,0 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReviewEvent(Enum):
|
||||
COMMENT = "COMMENT"
|
||||
APPROVE = "APPROVE"
|
||||
REQUEST_CHANGES = "REQUEST_CHANGES"
|
||||
|
||||
|
||||
class GithubCreatePRReviewBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class ReviewComment(TypedDict, total=False):
|
||||
path: str
|
||||
position: Optional[int]
|
||||
body: str
|
||||
line: Optional[int] # Will be used as position if position not provided
|
||||
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
body: str = SchemaField(
|
||||
description="Body of the review comment",
|
||||
placeholder="Enter your review comment",
|
||||
)
|
||||
event: ReviewEvent = SchemaField(
|
||||
description="The review action to perform",
|
||||
default=ReviewEvent.COMMENT,
|
||||
)
|
||||
create_as_draft: bool = SchemaField(
|
||||
description="Create the review as a draft (pending) or post it immediately",
|
||||
default=False,
|
||||
advanced=False,
|
||||
)
|
||||
comments: Optional[List[ReviewComment]] = SchemaField(
|
||||
description="Optional inline comments to add to specific files/lines. Note: Only path, body, and position are supported. Position is line number in diff from first @@ hunk.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
review_id: int = SchemaField(description="ID of the created review")
|
||||
state: str = SchemaField(
|
||||
description="State of the review (e.g., PENDING, COMMENTED, APPROVED, CHANGES_REQUESTED)"
|
||||
)
|
||||
html_url: str = SchemaField(description="URL of the created review")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the review creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="84754b30-97d2-4c37-a3b8-eb39f268275b",
|
||||
description="This block creates a review on a GitHub pull request with optional inline comments. You can create it as a draft or post immediately. Note: For inline comments, 'position' should be the line number in the diff (starting from the first @@ hunk header).",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreatePRReviewBlock.Input,
|
||||
output_schema=GithubCreatePRReviewBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"body": "This looks good to me!",
|
||||
"event": "APPROVE",
|
||||
"create_as_draft": False,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("review_id", 123456),
|
||||
("state", "APPROVED"),
|
||||
(
|
||||
"html_url",
|
||||
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_review": lambda *args, **kwargs: (
|
||||
123456,
|
||||
"APPROVED",
|
||||
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_review(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
body: str,
|
||||
event: ReviewEvent,
|
||||
create_as_draft: bool,
|
||||
comments: Optional[List[Input.ReviewComment]] = None,
|
||||
) -> tuple[int, str, str]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# GitHub API endpoint for creating reviews
|
||||
reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"
|
||||
|
||||
# Get commit_id if we have comments
|
||||
commit_id = None
|
||||
if comments:
|
||||
# Get PR details to get the head commit for inline comments
|
||||
pr_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}"
|
||||
pr_response = await api.get(pr_url)
|
||||
pr_data = pr_response.json()
|
||||
commit_id = pr_data["head"]["sha"]
|
||||
|
||||
# Prepare the request data
|
||||
# If create_as_draft is True, omit the event field (creates a PENDING review)
|
||||
# Otherwise, use the actual event value which will auto-submit the review
|
||||
data: dict[str, Any] = {"body": body}
|
||||
|
||||
# Add commit_id if we have it
|
||||
if commit_id:
|
||||
data["commit_id"] = commit_id
|
||||
|
||||
# Add comments if provided
|
||||
if comments:
|
||||
# Process comments to ensure they have the required fields
|
||||
processed_comments = []
|
||||
for comment in comments:
|
||||
comment_data: dict = {
|
||||
"path": comment.get("path", ""),
|
||||
"body": comment.get("body", ""),
|
||||
}
|
||||
# Add position or line
|
||||
# Note: For review comments, only position is supported (not line/side)
|
||||
if "position" in comment and comment.get("position") is not None:
|
||||
comment_data["position"] = comment.get("position")
|
||||
elif "line" in comment and comment.get("line") is not None:
|
||||
# Note: Using line as position - may not work correctly
|
||||
# Position should be calculated from the diff
|
||||
comment_data["position"] = comment.get("line")
|
||||
|
||||
# Note: side, start_line, and start_side are NOT supported for review comments
|
||||
# They are only for standalone PR comments
|
||||
|
||||
processed_comments.append(comment_data)
|
||||
|
||||
data["comments"] = processed_comments
|
||||
|
||||
if not create_as_draft:
|
||||
# Only add event field if not creating a draft
|
||||
data["event"] = event.value
|
||||
|
||||
# Create the review
|
||||
response = await api.post(reviews_url, json=data)
|
||||
review_data = response.json()
|
||||
|
||||
return review_data["id"], review_data["state"], review_data["html_url"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
review_id, state, html_url = await self.create_review(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
input_data.body,
|
||||
input_data.event,
|
||||
input_data.create_as_draft,
|
||||
input_data.comments,
|
||||
)
|
||||
yield "review_id", review_id
|
||||
yield "state", state
|
||||
yield "html_url", html_url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubListPRReviewsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class ReviewItem(TypedDict):
|
||||
id: int
|
||||
user: str
|
||||
state: str
|
||||
body: str
|
||||
html_url: str
|
||||
|
||||
review: ReviewItem = SchemaField(
|
||||
title="Review",
|
||||
description="Individual review with details",
|
||||
)
|
||||
reviews: list[ReviewItem] = SchemaField(
|
||||
description="List of all reviews on the pull request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing reviews failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f79bc6eb-33c0-4099-9c0f-d664ae1ba4d0",
|
||||
description="This block lists all reviews for a specified GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListPRReviewsBlock.Input,
|
||||
output_schema=GithubListPRReviewsBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"reviews",
|
||||
[
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"state": "APPROVED",
|
||||
"body": "Looks good!",
|
||||
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"review",
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"state": "APPROVED",
|
||||
"body": "Looks good!",
|
||||
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_reviews": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"state": "APPROVED",
|
||||
"body": "Looks good!",
|
||||
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_reviews(
|
||||
credentials: GithubCredentials, repo: str, pr_number: int
|
||||
) -> list[Output.ReviewItem]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# GitHub API endpoint for listing reviews
|
||||
reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"
|
||||
|
||||
response = await api.get(reviews_url)
|
||||
data = response.json()
|
||||
|
||||
reviews: list[GithubListPRReviewsBlock.Output.ReviewItem] = [
|
||||
{
|
||||
"id": review["id"],
|
||||
"user": review["user"]["login"],
|
||||
"state": review["state"],
|
||||
"body": review.get("body", ""),
|
||||
"html_url": review["html_url"],
|
||||
}
|
||||
for review in data
|
||||
]
|
||||
return reviews
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
reviews = await self.list_reviews(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
)
|
||||
yield "reviews", reviews
|
||||
for review in reviews:
|
||||
yield "review", review
|
||||
|
||||
|
||||
class GithubSubmitPendingReviewBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
review_id: int = SchemaField(
|
||||
description="ID of the pending review to submit",
|
||||
placeholder="123456",
|
||||
)
|
||||
event: ReviewEvent = SchemaField(
|
||||
description="The review action to perform when submitting",
|
||||
default=ReviewEvent.COMMENT,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
state: str = SchemaField(description="State of the submitted review")
|
||||
html_url: str = SchemaField(description="URL of the submitted review")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the review submission failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e468217-7ca0-4201-9553-36e93eb9357a",
|
||||
description="This block submits a pending (draft) review on a GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubSubmitPendingReviewBlock.Input,
|
||||
output_schema=GithubSubmitPendingReviewBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"review_id": 123456,
|
||||
"event": "APPROVE",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("state", "APPROVED"),
|
||||
(
|
||||
"html_url",
|
||||
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"submit_review": lambda *args, **kwargs: (
|
||||
"APPROVED",
|
||||
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def submit_review(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
review_id: int,
|
||||
event: ReviewEvent,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# GitHub API endpoint for submitting a review
|
||||
submit_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/events"
|
||||
|
||||
data = {"event": event.value}
|
||||
|
||||
response = await api.post(submit_url, json=data)
|
||||
review_data = response.json()
|
||||
|
||||
return review_data["state"], review_data["html_url"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
state, html_url = await self.submit_review(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
input_data.review_id,
|
||||
input_data.event,
|
||||
)
|
||||
yield "state", state
|
||||
yield "html_url", html_url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubResolveReviewDiscussionBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
comment_id: int = SchemaField(
|
||||
description="ID of the review comment to resolve/unresolve",
|
||||
placeholder="123456",
|
||||
)
|
||||
resolve: bool = SchemaField(
|
||||
description="Whether to resolve (true) or unresolve (false) the discussion",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the operation was successful")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b4b8a38c-95ae-4c91-9ef8-c2cffaf2b5d1",
|
||||
description="This block resolves or unresolves a review discussion thread on a GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubResolveReviewDiscussionBlock.Input,
|
||||
output_schema=GithubResolveReviewDiscussionBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"comment_id": 123456,
|
||||
"resolve": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"resolve_discussion": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def resolve_discussion(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
comment_id: int,
|
||||
resolve: bool,
|
||||
) -> bool:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# Extract owner and repo name
|
||||
parts = repo.split("/")
|
||||
owner = parts[0]
|
||||
repo_name = parts[1]
|
||||
|
||||
# GitHub GraphQL API is needed for resolving/unresolving discussions
|
||||
# First, we need to get the node ID of the comment
|
||||
graphql_url = "https://api.github.com/graphql"
|
||||
|
||||
# Query to get the review comment node ID
|
||||
query = """
|
||||
query($owner: String!, $repo: String!, $number: Int!) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
pullRequest(number: $number) {
|
||||
reviewThreads(first: 100) {
|
||||
nodes {
|
||||
comments(first: 100) {
|
||||
nodes {
|
||||
databaseId
|
||||
id
|
||||
}
|
||||
}
|
||||
id
|
||||
isResolved
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
variables = {"owner": owner, "repo": repo_name, "number": pr_number}
|
||||
|
||||
response = await api.post(
|
||||
graphql_url, json={"query": query, "variables": variables}
|
||||
)
|
||||
data = response.json()
|
||||
|
||||
# Find the thread containing our comment
|
||||
thread_id = None
|
||||
for thread in data["data"]["repository"]["pullRequest"]["reviewThreads"][
|
||||
"nodes"
|
||||
]:
|
||||
for comment in thread["comments"]["nodes"]:
|
||||
if comment["databaseId"] == comment_id:
|
||||
thread_id = thread["id"]
|
||||
break
|
||||
if thread_id:
|
||||
break
|
||||
|
||||
if not thread_id:
|
||||
raise ValueError(f"Comment {comment_id} not found in pull request")
|
||||
|
||||
# Now resolve or unresolve the thread
|
||||
# GitHub's GraphQL API has separate mutations for resolve and unresolve
|
||||
if resolve:
|
||||
mutation = """
|
||||
mutation($threadId: ID!) {
|
||||
resolveReviewThread(input: {threadId: $threadId}) {
|
||||
thread {
|
||||
isResolved
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
else:
|
||||
mutation = """
|
||||
mutation($threadId: ID!) {
|
||||
unresolveReviewThread(input: {threadId: $threadId}) {
|
||||
thread {
|
||||
isResolved
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
mutation_variables = {"threadId": thread_id}
|
||||
|
||||
response = await api.post(
|
||||
graphql_url, json={"query": mutation, "variables": mutation_variables}
|
||||
)
|
||||
result = response.json()
|
||||
|
||||
if "errors" in result:
|
||||
raise Exception(f"GraphQL error: {result['errors']}")
|
||||
|
||||
return True
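# Shape of the GraphQL response the loop above walks (trimmed illustration; field names
# follow the query above, values are made up):
#
#   {"data": {"repository": {"pullRequest": {"reviewThreads": {"nodes": [
#       {"id": "PRRT_abc", "isResolved": False,
#        "comments": {"nodes": [{"databaseId": 123456, "id": "PRRC_def"}]}},
#   ]}}}}}
#
# Matching databaseId == comment_id yields thread_id "PRRT_abc", which is then passed to
# the resolveReviewThread / unresolveReviewThread mutation.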
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = await self.resolve_discussion(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
input_data.comment_id,
|
||||
input_data.resolve,
|
||||
)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "success", False
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubGetPRReviewCommentsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
review_id: Optional[int] = SchemaField(
|
||||
description="ID of a specific review to get comments from (optional)",
|
||||
placeholder="123456",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class CommentItem(TypedDict):
|
||||
id: int
|
||||
user: str
|
||||
body: str
|
||||
path: str
|
||||
line: int
|
||||
side: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
in_reply_to_id: Optional[int]
|
||||
html_url: str
|
||||
|
||||
comment: CommentItem = SchemaField(
|
||||
title="Comment",
|
||||
description="Individual review comment with details",
|
||||
)
|
||||
comments: list[CommentItem] = SchemaField(
|
||||
description="List of all review comments on the pull request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if getting comments failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1d34db7f-10c1-45c1-9d43-749f743c8bd4",
|
||||
description="This block gets all review comments from a GitHub pull request or from a specific review.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetPRReviewCommentsBlock.Input,
|
||||
output_schema=GithubGetPRReviewCommentsBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"comments",
|
||||
[
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"body": "This needs improvement",
|
||||
"path": "src/main.py",
|
||||
"line": 42,
|
||||
"side": "RIGHT",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"in_reply_to_id": None,
|
||||
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"comment",
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"body": "This needs improvement",
|
||||
"path": "src/main.py",
|
||||
"line": 42,
|
||||
"side": "RIGHT",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"in_reply_to_id": None,
|
||||
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_comments": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"body": "This needs improvement",
|
||||
"path": "src/main.py",
|
||||
"line": 42,
|
||||
"side": "RIGHT",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"in_reply_to_id": None,
|
||||
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_comments(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
review_id: Optional[int] = None,
|
||||
) -> list[Output.CommentItem]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# Determine the endpoint based on whether we want comments from a specific review
|
||||
if review_id:
|
||||
# Get comments from a specific review
|
||||
comments_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/comments"
|
||||
else:
|
||||
# Get all review comments on the PR
|
||||
comments_url = (
|
||||
f"https://api.github.com/repos/{repo}/pulls/{pr_number}/comments"
|
||||
)
|
||||
|
||||
response = await api.get(comments_url)
|
||||
data = response.json()
|
||||
|
||||
comments: list[GithubGetPRReviewCommentsBlock.Output.CommentItem] = [
|
||||
{
|
||||
"id": comment["id"],
|
||||
"user": comment["user"]["login"],
|
||||
"body": comment["body"],
|
||||
"path": comment.get("path", ""),
|
||||
"line": comment.get("line", 0),
|
||||
"side": comment.get("side", ""),
|
||||
"created_at": comment["created_at"],
|
||||
"updated_at": comment["updated_at"],
|
||||
"in_reply_to_id": comment.get("in_reply_to_id"),
|
||||
"html_url": comment["html_url"],
|
||||
}
|
||||
for comment in data
|
||||
]
|
||||
return comments
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
comments = await self.get_comments(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
input_data.review_id,
|
||||
)
|
||||
yield "comments", comments
|
||||
for comment in comments:
|
||||
yield "comment", comment
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateCommentObjectBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
path: str = SchemaField(
|
||||
description="The file path to comment on",
|
||||
placeholder="src/main.py",
|
||||
)
|
||||
body: str = SchemaField(
|
||||
description="The comment text",
|
||||
placeholder="Please fix this issue",
|
||||
)
|
||||
position: Optional[int] = SchemaField(
|
||||
description="Position in the diff (line number from first @@ hunk). Use this OR line.",
|
||||
placeholder="6",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
line: Optional[int] = SchemaField(
|
||||
description="Line number in the file (will be used as position if position not provided)",
|
||||
placeholder="42",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
side: Optional[str] = SchemaField(
|
||||
description="Side of the diff to comment on (NOTE: Only for standalone comments, not review comments)",
|
||||
default="RIGHT",
|
||||
advanced=True,
|
||||
)
|
||||
start_line: Optional[int] = SchemaField(
|
||||
description="Start line for multi-line comments (NOTE: Only for standalone comments, not review comments)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
start_side: Optional[str] = SchemaField(
|
||||
description="Side for the start of multi-line comments (NOTE: Only for standalone comments, not review comments)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
comment_object: dict = SchemaField(
|
||||
description="The comment object formatted for GitHub API"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b7d5e4f2-8c3a-4e6b-9f1d-7a8b9c5e4d3f",
|
||||
description="Creates a comment object for use with GitHub blocks. Note: For review comments, only path, body, and position are used. Side fields are only for standalone PR comments.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateCommentObjectBlock.Input,
|
||||
output_schema=GithubCreateCommentObjectBlock.Output,
|
||||
test_input={
|
||||
"path": "src/main.py",
|
||||
"body": "Please fix this issue",
|
||||
"position": 6,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"comment_object",
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"body": "Please fix this issue",
|
||||
"position": 6,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Build the comment object
|
||||
comment_obj: dict = {
|
||||
"path": input_data.path,
|
||||
"body": input_data.body,
|
||||
}
|
||||
|
||||
# Add position or line
|
||||
if input_data.position is not None:
|
||||
comment_obj["position"] = input_data.position
|
||||
elif input_data.line is not None:
|
||||
# Note: line will be used as position, which may not be accurate
|
||||
# Position should be calculated from the diff
|
||||
comment_obj["position"] = input_data.line
|
||||
|
||||
# Add optional fields only if they differ from defaults or are explicitly provided
|
||||
if input_data.side and input_data.side != "RIGHT":
|
||||
comment_obj["side"] = input_data.side
|
||||
if input_data.start_line is not None:
|
||||
comment_obj["start_line"] = input_data.start_line
|
||||
if input_data.start_side:
|
||||
comment_obj["start_side"] = input_data.start_side
|
||||
|
||||
yield "comment_object", comment_obj
|
||||
@@ -21,8 +21,6 @@ from ._auth import (
    GoogleCredentialsInput,
)

settings = Settings()


class CalendarEvent(BaseModel):
    """Structured representation of a Google Calendar event."""
@@ -223,8 +221,8 @@ class GoogleCalendarReadEventsBlock(Block):
                else None
            ),
            token_uri="https://oauth2.googleapis.com/token",
            client_id=settings.secrets.google_client_id,
            client_secret=settings.secrets.google_client_secret,
            client_id=Settings().secrets.google_client_id,
            client_secret=Settings().secrets.google_client_secret,
            scopes=credentials.scopes,
        )
        return build("calendar", "v3", credentials=creds)
@@ -571,8 +569,8 @@ class GoogleCalendarCreateEventBlock(Block):
                else None
            ),
            token_uri="https://oauth2.googleapis.com/token",
            client_id=settings.secrets.google_client_id,
            client_secret=settings.secrets.google_client_secret,
            client_id=Settings().secrets.google_client_id,
            client_secret=Settings().secrets.google_client_secret,
            scopes=credentials.scopes,
        )
        return build("calendar", "v3", credentials=creds)

@@ -21,8 +21,6 @@ from ._auth import (
    GoogleCredentialsInput,
)

settings = Settings()


def serialize_email_recipients(recipients: list[str]) -> str:
    """Serialize recipients list to comma-separated string."""
@@ -257,8 +255,8 @@ class GmailReadBlock(Block):
                else None
            ),
            token_uri="https://oauth2.googleapis.com/token",
            client_id=settings.secrets.google_client_id,
            client_secret=settings.secrets.google_client_secret,
            client_id=Settings().secrets.google_client_id,
            client_secret=Settings().secrets.google_client_secret,
            scopes=credentials.scopes,
        )
        return build("gmail", "v1", credentials=creds)

@@ -81,11 +81,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    O3 = "o3-2025-04-16"
    O1 = "o1"
    O1_MINI = "o1-mini"
    # GPT-5 models
    GPT5 = "gpt-5-2025-08-07"
    GPT5_MINI = "gpt-5-mini-2025-08-07"
    GPT5_NANO = "gpt-5-nano-2025-08-07"
    GPT5_CHAT = "gpt-5-chat-latest"
    GPT41 = "gpt-4.1-2025-04-14"
    GPT41_MINI = "gpt-4.1-mini-2025-04-14"
    GPT4O_MINI = "gpt-4o-mini"
@@ -93,7 +88,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    GPT4_TURBO = "gpt-4-turbo"
    GPT3_5_TURBO = "gpt-3.5-turbo"
    # Anthropic models
    CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
    CLAUDE_4_OPUS = "claude-opus-4-20250514"
    CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
    CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@@ -179,11 +173,6 @@ MODEL_METADATA = {
    LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000),  # o3-mini-2025-01-31
    LlmModel.O1: ModelMetadata("openai", 200000, 100000),  # o1-2024-12-17
    LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536),  # o1-mini-2024-09-12
    # GPT-5 models
    LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
    LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
    LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
    LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
    LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
    LlmModel.GPT41_MINI: ModelMetadata("openai", 1047576, 32768),
    LlmModel.GPT4O_MINI: ModelMetadata(
@@ -195,9 +184,6 @@ MODEL_METADATA = {
    ),  # gpt-4-turbo-2024-04-09
    LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096),  # gpt-3.5-turbo-0125
    # https://docs.anthropic.com/en/docs/about-claude/models
    LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
        "anthropic", 200000, 32000
    ),  # claude-opus-4-1-20250805
    LlmModel.CLAUDE_4_OPUS: ModelMetadata(
        "anthropic", 200000, 8192
    ),  # claude-4-opus-20250514
@@ -493,7 +479,6 @@ async def llm_call(
        messages=messages,
        max_tokens=max_tokens,
        tools=an_tools,
        timeout=600,
    )

    if not resp.content:

@@ -3,7 +3,8 @@ from typing import List

from backend.data.block import BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.settings import BehaveAs, Settings
from backend.util import settings
from backend.util.settings import BehaveAs

from ._api import (
    TEST_CREDENTIALS,
@@ -15,8 +16,6 @@ from ._api import (
)
from .base import Slant3DBlockBase

settings = Settings()


class Slant3DCreateOrderBlock(Slant3DBlockBase):
    """Block for creating new orders"""
@@ -281,7 +280,7 @@ class Slant3DGetOrdersBlock(Slant3DBlockBase):
            input_schema=self.Input,
            output_schema=self.Output,
            # This block is disabled for cloud hosted because it allows access to all orders for the account
            disabled=settings.config.behave_as == BehaveAs.CLOUD,
            disabled=settings.Settings().config.behave_as == BehaveAs.CLOUD,
            test_input={"credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[

@@ -9,7 +9,8 @@ from backend.data.block import (
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.util.settings import AppEnvironment, BehaveAs, Settings
from backend.util import settings
from backend.util.settings import AppEnvironment, BehaveAs

from ._api import (
    TEST_CREDENTIALS,
@@ -18,8 +19,6 @@ from ._api import (
    Slant3DCredentialsInput,
)

settings = Settings()


class Slant3DTriggerBase:
    """Base class for Slant3D webhook triggers"""
@@ -77,8 +76,8 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
            ),
            # All webhooks are currently subscribed to for all orders. This works for self hosted, but not for cloud hosted prod
            disabled=(
                settings.config.behave_as == BehaveAs.CLOUD
                and settings.config.app_env != AppEnvironment.LOCAL
                settings.Settings().config.behave_as == BehaveAs.CLOUD
                and settings.Settings().config.app_env != AppEnvironment.LOCAL
            ),
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,

@@ -1,8 +1,9 @@
import logging

import pytest
from prisma.models import User

from backend.data.model import ProviderName, User
from backend.data.model import ProviderName
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user

@@ -1,78 +1,19 @@
import asyncio
import time
from datetime import datetime, timedelta
from typing import Any, Literal, Union
from zoneinfo import ZoneInfo

from pydantic import BaseModel
from typing import Any, Union

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField

# Shared timezone literal type for all time/date blocks
|
||||
TimezoneLiteral = Literal[
|
||||
"UTC", # UTC±00:00
|
||||
"Pacific/Honolulu", # UTC-10:00
|
||||
"America/Anchorage", # UTC-09:00 (Alaska)
|
||||
"America/Los_Angeles", # UTC-08:00 (Pacific)
|
||||
"America/Denver", # UTC-07:00 (Mountain)
|
||||
"America/Chicago", # UTC-06:00 (Central)
|
||||
"America/New_York", # UTC-05:00 (Eastern)
|
||||
"America/Caracas", # UTC-04:00
|
||||
"America/Sao_Paulo", # UTC-03:00
|
||||
"America/St_Johns", # UTC-02:30 (Newfoundland)
|
||||
"Atlantic/South_Georgia", # UTC-02:00
|
||||
"Atlantic/Azores", # UTC-01:00
|
||||
"Europe/London", # UTC+00:00 (GMT/BST)
|
||||
"Europe/Paris", # UTC+01:00 (CET)
|
||||
"Europe/Athens", # UTC+02:00 (EET)
|
||||
"Europe/Moscow", # UTC+03:00
|
||||
"Asia/Tehran", # UTC+03:30 (Iran)
|
||||
"Asia/Dubai", # UTC+04:00
|
||||
"Asia/Kabul", # UTC+04:30 (Afghanistan)
|
||||
"Asia/Karachi", # UTC+05:00 (Pakistan)
|
||||
"Asia/Kolkata", # UTC+05:30 (India)
|
||||
"Asia/Kathmandu", # UTC+05:45 (Nepal)
|
||||
"Asia/Dhaka", # UTC+06:00 (Bangladesh)
|
||||
"Asia/Yangon", # UTC+06:30 (Myanmar)
|
||||
"Asia/Bangkok", # UTC+07:00
|
||||
"Asia/Shanghai", # UTC+08:00 (China)
|
||||
"Australia/Eucla", # UTC+08:45
|
||||
"Asia/Tokyo", # UTC+09:00 (Japan)
|
||||
"Australia/Adelaide", # UTC+09:30
|
||||
"Australia/Sydney", # UTC+10:00
|
||||
"Australia/Lord_Howe", # UTC+10:30
|
||||
"Pacific/Noumea", # UTC+11:00
|
||||
"Pacific/Auckland", # UTC+12:00 (New Zealand)
|
||||
"Pacific/Chatham", # UTC+12:45
|
||||
"Pacific/Tongatapu", # UTC+13:00
|
||||
"Pacific/Kiritimati", # UTC+14:00
|
||||
"Etc/GMT-12", # UTC+12:00
|
||||
"Etc/GMT+12", # UTC-12:00
|
||||
]
|
||||
|
||||
|
||||
class TimeStrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%H:%M:%S"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
|
||||
|
||||
class TimeISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
include_microseconds: bool = False
|
||||
|
||||
|
||||
class GetCurrentTimeBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
trigger: str = SchemaField(
|
||||
description="Trigger any data to output the current time"
|
||||
)
|
||||
format_type: Union[TimeStrftimeFormat, TimeISO8601Format] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Format type for time output (strftime with custom format or ISO 8601)",
|
||||
default=TimeStrftimeFormat(discriminator="strftime"),
|
||||
format: str = SchemaField(
|
||||
description="Format of the time to output", default="%H:%M:%S"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -89,65 +30,19 @@ class GetCurrentTimeBlock(Block):
|
||||
output_schema=GetCurrentTimeBlock.Output,
|
||||
test_input=[
|
||||
{"trigger": "Hello"},
|
||||
{
|
||||
"trigger": "Hello",
|
||||
"format_type": {
|
||||
"discriminator": "strftime",
|
||||
"format": "%H:%M",
|
||||
},
|
||||
},
|
||||
{
|
||||
"trigger": "Hello",
|
||||
"format_type": {
|
||||
"discriminator": "iso8601",
|
||||
"timezone": "UTC",
|
||||
"include_microseconds": False,
|
||||
},
|
||||
},
|
||||
{"trigger": "Hello", "format": "%H:%M"},
|
||||
],
|
||||
test_output=[
|
||||
("time", lambda _: time.strftime("%H:%M:%S")),
|
||||
("time", lambda _: time.strftime("%H:%M")),
|
||||
(
|
||||
"time",
|
||||
lambda t: "T" in t and ("+" in t or "Z" in t),
|
||||
), # Check for ISO format with timezone
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if isinstance(input_data.format_type, TimeISO8601Format):
|
||||
# ISO 8601 format for time only (extract time portion from full ISO datetime)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
# Get the full ISO format and extract just the time portion with timezone
|
||||
if input_data.format_type.include_microseconds:
|
||||
full_iso = dt.isoformat()
|
||||
else:
|
||||
full_iso = dt.isoformat(timespec="seconds")
|
||||
|
||||
# Extract time portion (everything after 'T')
|
||||
current_time = full_iso.split("T")[1] if "T" in full_iso else full_iso
|
||||
current_time = f"T{current_time}" # Add T prefix for ISO 8601 time format
|
||||
else: # TimeStrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
current_time = dt.strftime(input_data.format_type.format)
|
||||
current_time = time.strftime(input_data.format)
|
||||
yield "time", current_time
|
||||
|
||||
|
||||
class DateStrftimeFormat(BaseModel):
    discriminator: Literal["strftime"]
    format: str = "%Y-%m-%d"
    timezone: TimezoneLiteral = "UTC"


class DateISO8601Format(BaseModel):
    discriminator: Literal["iso8601"]
    timezone: TimezoneLiteral = "UTC"


class GetCurrentDateBlock(Block):
    class Input(BlockSchema):
        trigger: str = SchemaField(
@@ -158,10 +53,8 @@ class GetCurrentDateBlock(Block):
            description="Offset in days from the current date",
            default=0,
        )
        format_type: Union[DateStrftimeFormat, DateISO8601Format] = SchemaField(
            discriminator="discriminator",
            description="Format type for date output (strftime with custom format or ISO 8601)",
            default=DateStrftimeFormat(discriminator="strftime"),
        format: str = SchemaField(
            description="Format of the date to output", default="%Y-%m-%d"
        )

    class Output(BlockSchema):
@@ -178,22 +71,7 @@ class GetCurrentDateBlock(Block):
            output_schema=GetCurrentDateBlock.Output,
            test_input=[
                {"trigger": "Hello", "offset": "7"},
                {
                    "trigger": "Hello",
                    "offset": "7",
                    "format_type": {
                        "discriminator": "strftime",
                        "format": "%m/%d/%Y",
                    },
                },
                {
                    "trigger": "Hello",
                    "offset": "0",
                    "format_type": {
                        "discriminator": "iso8601",
                        "timezone": "UTC",
                    },
                },
                {"trigger": "Hello", "offset": "7", "format": "%m/%d/%Y"},
            ],
            test_output=[
                (
@@ -207,12 +85,6 @@ class GetCurrentDateBlock(Block):
                    < timedelta(days=8),
                    # 7 days difference + 1 day error margin.
                ),
                (
                    "date",
                    lambda t: len(t) == 10
                    and t[4] == "-"
                    and t[7] == "-", # ISO date format YYYY-MM-DD
                ),
            ],
        )

@@ -221,31 +93,8 @@ class GetCurrentDateBlock(Block):
            offset = int(input_data.offset)
        except ValueError:
            offset = 0

        if isinstance(input_data.format_type, DateISO8601Format):
            # ISO 8601 format for date only (YYYY-MM-DD)
            tz = ZoneInfo(input_data.format_type.timezone)
            current_date = datetime.now(tz=tz) - timedelta(days=offset)
            # ISO 8601 date format is YYYY-MM-DD
            date_str = current_date.date().isoformat()
        else: # DateStrftimeFormat
            tz = ZoneInfo(input_data.format_type.timezone)
            current_date = datetime.now(tz=tz) - timedelta(days=offset)
            date_str = current_date.strftime(input_data.format_type.format)

        yield "date", date_str


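Because the run method computes `datetime.now(tz) - timedelta(days=offset)`, a positive offset moves the result into the past; a quick stdlib illustration (not the block API):

from datetime import datetime, timedelta

offset = 7
print((datetime.now() - timedelta(days=offset)).date().isoformat())  # the date one week ago
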
class StrftimeFormat(BaseModel):
    discriminator: Literal["strftime"]
    format: str = "%Y-%m-%d %H:%M:%S"
    timezone: TimezoneLiteral = "UTC"


class ISO8601Format(BaseModel):
    discriminator: Literal["iso8601"]
    timezone: TimezoneLiteral = "UTC"
    include_microseconds: bool = False
        current_date = datetime.now() - timedelta(days=offset)
        yield "date", current_date.strftime(input_data.format)


class GetCurrentDateAndTimeBlock(Block):
@@ -253,10 +102,9 @@ class GetCurrentDateAndTimeBlock(Block):
        trigger: str = SchemaField(
            description="Trigger any data to output the current date and time"
        )
        format_type: Union[StrftimeFormat, ISO8601Format] = SchemaField(
            discriminator="discriminator",
            description="Format type for date and time output (strftime with custom format or ISO 8601/RFC 3339)",
            default=StrftimeFormat(discriminator="strftime"),
        format: str = SchemaField(
            description="Format of the date and time to output",
            default="%Y-%m-%d %H:%M:%S",
        )

    class Output(BlockSchema):
@@ -273,63 +121,20 @@ class GetCurrentDateAndTimeBlock(Block):
            output_schema=GetCurrentDateAndTimeBlock.Output,
            test_input=[
                {"trigger": "Hello"},
                {
                    "trigger": "Hello",
                    "format_type": {
                        "discriminator": "strftime",
                        "format": "%Y/%m/%d",
                    },
                },
                {
                    "trigger": "Hello",
                    "format_type": {
                        "discriminator": "iso8601",
                        "timezone": "UTC",
                        "include_microseconds": False,
                    },
                },
            ],
            test_output=[
                (
                    "date_time",
                    lambda t: abs(
                        datetime.now(tz=ZoneInfo("UTC"))
                        - datetime.strptime(t + "+00:00", "%Y-%m-%d %H:%M:%S%z")
                        datetime.now() - datetime.strptime(t, "%Y-%m-%d %H:%M:%S")
                    )
                    < timedelta(seconds=10), # 10 seconds error margin.
                ),
                (
                    "date_time",
                    lambda t: abs(
                        datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
                    )
                    < timedelta(days=1), # Date format only, no time component
                ),
                (
                    "date_time",
                    lambda t: abs(
                        datetime.now(tz=ZoneInfo("UTC")) - datetime.fromisoformat(t)
                    )
                    < timedelta(seconds=10), # 10 seconds error margin for ISO format.
                ),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        if isinstance(input_data.format_type, ISO8601Format):
            # ISO 8601 format with specified timezone (also RFC3339-compliant)
            tz = ZoneInfo(input_data.format_type.timezone)
            dt = datetime.now(tz=tz)

            # Format with or without microseconds
            if input_data.format_type.include_microseconds:
                current_date_time = dt.isoformat()
            else:
                current_date_time = dt.isoformat(timespec="seconds")
        else: # StrftimeFormat
            tz = ZoneInfo(input_data.format_type.timezone)
            dt = datetime.now(tz=tz)
            current_date_time = dt.strftime(input_data.format_type.format)
        current_date_time = time.strftime(input_data.format)
        yield "date_time", current_date_time


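For reference, `datetime.isoformat(timespec="seconds")` drops the fractional seconds while keeping the UTC offset, which is what makes the non-microsecond branch above RFC 3339-friendly; a quick stdlib check:

from datetime import datetime
from zoneinfo import ZoneInfo

dt = datetime(2024, 1, 2, 3, 4, 5, 678901, tzinfo=ZoneInfo("UTC"))
print(dt.isoformat())                    # 2024-01-02T03:04:05.678901+00:00
print(dt.isoformat(timespec="seconds"))  # 2024-01-02T03:04:05+00:00
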
@@ -48,18 +48,12 @@ MODEL_COST: dict[LlmModel, int] = {
    LlmModel.O3_MINI: 2, # $1.10 / $4.40
    LlmModel.O1: 16, # $15 / $60
    LlmModel.O1_MINI: 4,
    # GPT-5 models
    LlmModel.GPT5: 2,
    LlmModel.GPT5_MINI: 1,
    LlmModel.GPT5_NANO: 1,
    LlmModel.GPT5_CHAT: 2,
    LlmModel.GPT41: 2,
    LlmModel.GPT41_MINI: 1,
    LlmModel.GPT4O_MINI: 1,
    LlmModel.GPT4O: 3,
    LlmModel.GPT4_TURBO: 10,
    LlmModel.GPT3_5_TURBO: 1,
    LlmModel.CLAUDE_4_1_OPUS: 21,
    LlmModel.CLAUDE_4_OPUS: 21,
    LlmModel.CLAUDE_4_SONNET: 5,
    LlmModel.CLAUDE_3_7_SONNET: 5,

@@ -998,8 +998,8 @@ def get_block_costs() -> dict[str, list[BlockCost]]:
async def get_stripe_customer_id(user_id: str) -> str:
    user = await get_user_by_id(user_id)

    if user.stripe_customer_id:
        return user.stripe_customer_id
    if user.stripeCustomerId:
        return user.stripeCustomerId

    customer = stripe.Customer.create(
        name=user.name or "",
@@ -1022,10 +1022,10 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig):
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
    user = await get_user_by_id(user_id)

    if not user.top_up_config:
    if not user.topUpConfig:
        return AutoTopUpConfig(threshold=0, amount=0)

    return AutoTopUpConfig.model_validate(user.top_up_config)
    return AutoTopUpConfig.model_validate(user.topUpConfig)


async def admin_get_user_history(

@@ -5,7 +5,6 @@ import enum
import logging
from collections import defaultdict
from datetime import datetime, timezone
from json import JSONDecodeError
from typing import (
    TYPE_CHECKING,
    Annotated,
@@ -41,120 +40,12 @@ from pydantic_core import (
from typing_extensions import TypedDict

from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.settings import Secrets

# Type alias for any provider name (including custom ones)
AnyProviderName = str # Will be validated as ProviderName at runtime


class User(BaseModel):
    """Application-layer User model with snake_case convention."""

    model_config = ConfigDict(
        extra="forbid",
        str_strip_whitespace=True,
    )

    id: str = Field(..., description="User ID")
    email: str = Field(..., description="User email address")
    email_verified: bool = Field(default=True, description="Whether email is verified")
    name: Optional[str] = Field(None, description="User display name")
    created_at: datetime = Field(..., description="When user was created")
    updated_at: datetime = Field(..., description="When user was last updated")
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="User metadata as dict"
    )
    integrations: str = Field(default="", description="Encrypted integrations data")
    stripe_customer_id: Optional[str] = Field(None, description="Stripe customer ID")
    top_up_config: Optional["AutoTopUpConfig"] = Field(
        None, description="Top up configuration"
    )

    # Notification preferences
    max_emails_per_day: int = Field(default=3, description="Maximum emails per day")
    notify_on_agent_run: bool = Field(default=True, description="Notify on agent run")
    notify_on_zero_balance: bool = Field(
        default=True, description="Notify on zero balance"
    )
    notify_on_low_balance: bool = Field(
        default=True, description="Notify on low balance"
    )
    notify_on_block_execution_failed: bool = Field(
        default=True, description="Notify on block execution failure"
    )
    notify_on_continuous_agent_error: bool = Field(
        default=True, description="Notify on continuous agent error"
    )
    notify_on_daily_summary: bool = Field(
        default=True, description="Notify on daily summary"
    )
    notify_on_weekly_summary: bool = Field(
        default=True, description="Notify on weekly summary"
    )
    notify_on_monthly_summary: bool = Field(
        default=True, description="Notify on monthly summary"
    )

    @classmethod
    def from_db(cls, prisma_user: "PrismaUser") -> "User":
        """Convert a database User object to application User model."""
        # Handle metadata field - convert from JSON string or dict to dict
        metadata = {}
        if prisma_user.metadata:
            if isinstance(prisma_user.metadata, str):
                try:
                    metadata = json_loads(prisma_user.metadata)
                except (JSONDecodeError, TypeError):
                    metadata = {}
            elif isinstance(prisma_user.metadata, dict):
                metadata = prisma_user.metadata

        # Handle topUpConfig field
        top_up_config = None
        if prisma_user.topUpConfig:
            if isinstance(prisma_user.topUpConfig, str):
                try:
                    config_dict = json_loads(prisma_user.topUpConfig)
                    top_up_config = AutoTopUpConfig.model_validate(config_dict)
                except (JSONDecodeError, TypeError, ValueError):
                    top_up_config = None
            elif isinstance(prisma_user.topUpConfig, dict):
                try:
                    top_up_config = AutoTopUpConfig.model_validate(
                        prisma_user.topUpConfig
                    )
                except ValueError:
                    top_up_config = None

        return cls(
            id=prisma_user.id,
            email=prisma_user.email,
            email_verified=prisma_user.emailVerified or True,
            name=prisma_user.name,
            created_at=prisma_user.createdAt,
            updated_at=prisma_user.updatedAt,
            metadata=metadata,
            integrations=prisma_user.integrations or "",
            stripe_customer_id=prisma_user.stripeCustomerId,
            top_up_config=top_up_config,
            max_emails_per_day=prisma_user.maxEmailsPerDay or 3,
            notify_on_agent_run=prisma_user.notifyOnAgentRun or True,
            notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True,
            notify_on_low_balance=prisma_user.notifyOnLowBalance or True,
            notify_on_block_execution_failed=prisma_user.notifyOnBlockExecutionFailed
            or True,
            notify_on_continuous_agent_error=prisma_user.notifyOnContinuousAgentError
            or True,
            notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
            notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
            notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
        )


if TYPE_CHECKING:
    from prisma.models import User as PrismaUser

    from backend.data.block import BlockSchema

T = TypeVar("T")

@@ -140,7 +140,6 @@ class SyncRabbitMQ(RabbitMQBase):
            socket_timeout=SOCKET_TIMEOUT,
            connection_attempts=CONNECTION_ATTEMPTS,
            retry_delay=RETRY_DELAY,
            heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
        )

        self._connection = pika.BlockingConnection(parameters)
@@ -245,7 +244,6 @@ class AsyncRabbitMQ(RabbitMQBase):
            password=self.password,
            virtualhost=self.config.vhost.lstrip("/"),
            blocked_connection_timeout=BLOCKED_CONNECTION_TIMEOUT,
            heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
        )
        self._channel = await self._connection.channel()
        await self._channel.set_qos(prefetch_count=1)

@@ -9,11 +9,11 @@ from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from fastapi import HTTPException
from prisma.enums import NotificationType
from prisma.models import User as PrismaUser
from prisma.models import User
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput

from backend.data.db import prisma
from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.model import UserIntegrations, UserMetadata
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.encryption import JSONCryptor
@@ -21,7 +21,6 @@ from backend.util.json import SafeJson
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
settings = Settings()


async def get_or_create_user(user_data: dict) -> User:
@@ -44,7 +43,7 @@ async def get_or_create_user(user_data: dict) -> User:
                )
            )

        return User.from_db(user)
        return User.model_validate(user)
    except Exception as e:
        raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e

@@ -53,7 +52,7 @@ async def get_user_by_id(user_id: str) -> User:
    user = await prisma.user.find_unique(where={"id": user_id})
    if not user:
        raise ValueError(f"User not found with ID: {user_id}")
    return User.from_db(user)
    return User.model_validate(user)


async def get_user_email_by_id(user_id: str) -> Optional[str]:
@@ -67,7 +66,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
async def get_user_by_email(email: str) -> Optional[User]:
    try:
        user = await prisma.user.find_unique(where={"email": email})
        return User.from_db(user) if user else None
        return User.model_validate(user) if user else None
    except Exception as e:
        raise DatabaseError(f"Failed to get user by email {email}: {e}") from e

@@ -91,11 +90,11 @@ async def create_default_user() -> Optional[User]:
                name="Default User",
            )
        )
    return User.from_db(user)
    return User.model_validate(user)


async def get_user_integrations(user_id: str) -> UserIntegrations:
    user = await PrismaUser.prisma().find_unique_or_raise(
    user = await User.prisma().find_unique_or_raise(
        where={"id": user_id},
    )

@@ -110,7 +109,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:

async def update_user_integrations(user_id: str, data: UserIntegrations):
    encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
    await PrismaUser.prisma().update(
    await User.prisma().update(
        where={"id": user_id},
        data={"integrations": encrypted_data},
    )
@@ -118,7 +117,7 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):

async def migrate_and_encrypt_user_integrations():
    """Migrate integration credentials and OAuth states from metadata to integrations column."""
    users = await PrismaUser.prisma().find_many(
    users = await User.prisma().find_many(
        where={
            "metadata": cast(
                JsonFilter,
@@ -154,7 +153,7 @@ async def migrate_and_encrypt_user_integrations():
        raw_metadata.pop("integration_oauth_states", None)

        # Update metadata without integration data
        await PrismaUser.prisma().update(
        await User.prisma().update(
            where={"id": user.id},
            data={"metadata": SafeJson(raw_metadata)},
        )
@@ -162,7 +161,7 @@ async def migrate_and_encrypt_user_integrations():

async def get_active_user_ids_in_timerange(start_time: str, end_time: str) -> list[str]:
    try:
        users = await PrismaUser.prisma().find_many(
        users = await User.prisma().find_many(
            where={
                "AgentGraphExecutions": {
                    "some": {
@@ -192,7 +191,7 @@ async def get_active_users_ids() -> list[str]:

async def get_user_notification_preference(user_id: str) -> NotificationPreference:
    try:
        user = await PrismaUser.prisma().find_unique_or_raise(
        user = await User.prisma().find_unique_or_raise(
            where={"id": user_id},
        )

@@ -269,7 +268,7 @@ async def update_user_notification_preference(
    if data.daily_limit:
        update_data["maxEmailsPerDay"] = data.daily_limit

    user = await PrismaUser.prisma().update(
    user = await User.prisma().update(
        where={"id": user_id},
        data=update_data,
    )
@@ -307,7 +306,7 @@ async def update_user_notification_preference(
async def set_user_email_verification(user_id: str, verified: bool) -> None:
    """Set the email verification status for a user."""
    try:
        await PrismaUser.prisma().update(
        await User.prisma().update(
            where={"id": user_id},
            data={"emailVerified": verified},
        )
@@ -320,7 +319,7 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
async def get_user_email_verification(user_id: str) -> bool:
    """Get the email verification status for a user."""
    try:
        user = await PrismaUser.prisma().find_unique_or_raise(
        user = await User.prisma().find_unique_or_raise(
            where={"id": user_id},
        )
        return user.emailVerified
@@ -333,7 +332,7 @@ async def get_user_email_verification(user_id: str) -> bool:
def generate_unsubscribe_link(user_id: str) -> str:
    """Generate a link to unsubscribe from all notifications"""
    # Create an HMAC using a secret key
    secret_key = settings.secrets.unsubscribe_secret_key
    secret_key = Settings().secrets.unsubscribe_secret_key
    signature = hmac.new(
        secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
    ).digest()
@@ -344,7 +343,7 @@ def generate_unsubscribe_link(user_id: str) -> str:
    ).decode("utf-8")
    logger.info(f"Generating unsubscribe link for user {user_id}")

    base_url = settings.config.platform_base_url
    base_url = Settings().config.platform_base_url
    return f"{base_url}/api/email/unsubscribe?token={quote_plus(token)}"


@@ -356,7 +355,7 @@ async def unsubscribe_user_by_token(token: str) -> None:
    user_id, received_signature_hex = decoded.split(":", 1)

    # Verify the signature
    secret_key = settings.secrets.unsubscribe_secret_key
    secret_key = Settings().secrets.unsubscribe_secret_key
    expected_signature = hmac.new(
        secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
    ).digest()

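The unsubscribe token above is built from an HMAC-SHA256 signature over the user id; verification recomputes the signature with the same secret and compares it. A hedged sketch of that check (helper names are illustrative, and the exact token encoding is not fully shown in this diff):

import hashlib
import hmac


def sign(user_id: str, secret_key: str) -> bytes:
    return hmac.new(
        secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
    ).digest()


def verify(user_id: str, received_signature: bytes, secret_key: str) -> bool:
    # compare_digest avoids leaking timing information about how many bytes match
    return hmac.compare_digest(sign(user_id, secret_key), received_signature)
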
@@ -102,10 +102,8 @@ async def generate_activity_status_for_execution(
    Returns:
        AI-generated activity status string, or None if feature is disabled
    """
    # Check LaunchDarkly feature flag for AI activity status generation with full context support
    if not await is_feature_enabled(
        AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False
    ):
    # Check LaunchDarkly feature flag for AI activity status generation
    if not is_feature_enabled(AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False):
        logger.debug("AI activity status generation is disabled via LaunchDarkly")
        return None


@@ -35,20 +35,13 @@ from backend.data.notifications import (
)
from backend.data.user import (
    get_active_user_ids_in_timerange,
    get_user_by_id,
    get_user_email_by_id,
    get_user_email_verification,
    get_user_integrations,
    get_user_notification_preference,
    update_user_integrations,
)
from backend.util.service import (
    AppService,
    AppServiceClient,
    UnhealthyServiceError,
    endpoint_to_sync,
    expose,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Config

config = Config()
@@ -80,10 +73,10 @@ class DatabaseManager(AppService):
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())

    async def health_check(self) -> str:
    def health_check(self) -> str:
        if not db.is_connected():
            raise UnhealthyServiceError("Database is not connected")
        return await super().health_check()
            raise RuntimeError("Database is not connected")
        return super().health_check()

    @classmethod
    def get_port(cls) -> int:
@@ -130,7 +123,6 @@ class DatabaseManager(AppService):

    # User Comms - async
    get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
    get_user_by_id = _(get_user_by_id)
    get_user_email_by_id = _(get_user_email_by_id)
    get_user_email_verification = _(get_user_email_verification)
    get_user_notification_preference = _(get_user_notification_preference)
@@ -171,6 +163,23 @@ class DatabaseManagerClient(AppServiceClient):
    spend_credits = _(d.spend_credits)
    get_credits = _(d.get_credits)

    # User Comms - async
    get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
    get_user_email_by_id = _(d.get_user_email_by_id)
    get_user_email_verification = _(d.get_user_email_verification)
    get_user_notification_preference = _(d.get_user_notification_preference)

    # Notifications - async
    create_or_add_to_user_notification_batch = _(
        d.create_or_add_to_user_notification_batch
    )
    empty_user_notification_batch = _(d.empty_user_notification_batch)
    get_all_batches_by_type = _(d.get_all_batches_by_type)
    get_user_notification_batch = _(d.get_user_notification_batch)
    get_user_notification_oldest_message_in_batch = _(
        d.get_user_notification_oldest_message_in_batch
    )

    # Block error monitoring
    get_block_error_stats = _(d.get_block_error_stats)

@@ -200,21 +209,3 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    update_user_integrations = d.update_user_integrations
    get_execution_kv_data = d.get_execution_kv_data
    set_execution_kv_data = d.set_execution_kv_data

    # User Comms
    get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
    get_user_by_id = d.get_user_by_id
    get_user_email_by_id = d.get_user_email_by_id
    get_user_email_verification = d.get_user_email_verification
    get_user_notification_preference = d.get_user_notification_preference

    # Notifications
    create_or_add_to_user_notification_batch = (
        d.create_or_add_to_user_notification_batch
    )
    empty_user_notification_batch = d.empty_user_notification_batch
    get_all_batches_by_type = d.get_all_batches_by_type
    get_user_notification_batch = d.get_user_notification_batch
    get_user_notification_oldest_message_in_batch = (
        d.get_user_notification_oldest_message_in_batch
    )

@@ -732,6 +732,7 @@ class ExecutionProcessor:
        log_metadata: LogMetadata,
        execution_stats: GraphExecutionStats,
    ) -> ExecutionStatus:

        """
        Returns:
            dict: The execution statistics of the graph execution.
@@ -1257,7 +1258,7 @@ class ExecutionManager(AppProcess):

    def _handle_run_message(
        self,
        _channel: BlockingChannel,
        channel: BlockingChannel,
        method: Basic.Deliver,
        _properties: BasicProperties,
        body: bytes,
@@ -1267,9 +1268,6 @@ class ExecutionManager(AppProcess):
        @func_retry
        def _ack_message(reject: bool = False):
            """Acknowledge or reject the message based on execution status."""

            # Connection can be lost, so always get a fresh channel
            channel = self.run_client.get_channel()
            if reject:
                channel.connection.add_callback_threadsafe(
                    lambda: channel.basic_nack(delivery_tag, requeue=True)
@@ -1361,25 +1359,6 @@ class ExecutionManager(AppProcess):
        else:
            utilization_gauge.set(active_count / self.pool_size)

    def _stop_message_consumers(
        self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
    ):
        try:
            channel = client.get_channel()
            channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())

            try:
                thread.join(timeout=300)
            except TimeoutError:
                logger.error(
                    f"{prefix} ⚠️ Run thread did not finish in time, forcing disconnect"
                )

            client.disconnect()
            logger.info(f"{prefix} ✅ Run client disconnected")
        except Exception as e:
            logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")

    def cleanup(self):
        """Override cleanup to implement graceful shutdown with active execution waiting."""
        prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
@@ -1435,16 +1414,26 @@ class ExecutionManager(AppProcess):
            logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")

        # Disconnect the run execution consumer
        self._stop_message_consumers(
            self.run_thread,
            self.run_client,
            prefix + " [run-consumer]",
        )
        self._stop_message_consumers(
            self.cancel_thread,
            self.cancel_client,
            prefix + " [cancel-consumer]",
        )
        try:
            run_channel = self.run_client.get_channel()
            run_channel.connection.add_callback_threadsafe(
                lambda: self.run_client.disconnect()
            )
            self.run_thread.join()
            logger.info(f"{prefix} ✅ Run client disconnected")
        except Exception as e:
            logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")

        # Disconnect the cancel execution consumer
        try:
            cancel_channel = self.cancel_client.get_channel()
            cancel_channel.connection.add_callback_threadsafe(
                lambda: self.cancel_client.disconnect()
            )
            self.cancel_thread.join()
            logger.info(f"{prefix} ✅ Cancel client disconnected")
        except Exception as e:
            logger.error(f"{prefix} ⚠️ Error disconnecting cancel client: {type(e)} {e}")

        logger.info(f"{prefix} ✅ Finished GraphExec cleanup")


@@ -3,6 +3,7 @@ import logging
import autogpt_libs.auth.models
import fastapi.responses
import pytest
from prisma.models import User

import backend.server.v2.library.model
import backend.server.v2.store.model
@@ -11,7 +12,6 @@ from backend.blocks.data_manipulation import FindInDictionaryBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.data.model import User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user

@@ -1,22 +1,17 @@
import asyncio
import logging
import os
import threading
from enum import Enum
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from apscheduler.events import (
    EVENT_JOB_ERROR,
    EVENT_JOB_EXECUTED,
    EVENT_JOB_MAX_INSTANCES,
    EVENT_JOB_MISSED,
)
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
from apscheduler.job import Job as JobObj
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
@@ -35,14 +30,7 @@ from backend.monitoring import (
from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.logging import PrefixFilter
from backend.util.retry import func_retry
from backend.util.service import (
    AppService,
    AppServiceClient,
    UnhealthyServiceError,
    endpoint_to_async,
    expose,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
from backend.util.settings import Config


@@ -72,69 +60,26 @@ apscheduler_logger.addFilter(PrefixFilter("[Scheduler] [APScheduler]"))

config = Config()

# Timeout constants
SCHEDULER_OPERATION_TIMEOUT_SECONDS = 300 # 5 minutes for scheduler operations


def job_listener(event):
    """Logs job execution outcomes for better monitoring."""
    if event.exception:
        logger.error(
            f"Job {event.job_id} failed: {type(event.exception).__name__}: {event.exception}"
        )
        logger.error(f"Job {event.job_id} failed.")
    else:
        logger.info(f"Job {event.job_id} completed successfully.")


def job_missed_listener(event):
    """Logs when jobs are missed due to scheduling issues."""
    logger.warning(
        f"Job {event.job_id} was missed at scheduled time {event.scheduled_run_time}. "
        f"This can happen if the scheduler is overloaded or if previous executions are still running."
    )


def job_max_instances_listener(event):
    """Logs when jobs hit max instances limit."""
    logger.warning(
        f"Job {event.job_id} execution was SKIPPED - max instances limit reached. "
        f"Previous execution(s) are still running. "
        f"Consider increasing max_instances or check why previous executions are taking too long."
    )


_event_loop: asyncio.AbstractEventLoop | None = None
_event_loop_thread: threading.Thread | None = None


@func_retry
@thread_cached
def get_event_loop():
    """Get the shared event loop."""
    if _event_loop is None:
        raise RuntimeError("Event loop not initialized. Scheduler not started.")
    return _event_loop


def run_async(coro, timeout: float = SCHEDULER_OPERATION_TIMEOUT_SECONDS):
    """Run a coroutine in the shared event loop and wait for completion."""
    loop = get_event_loop()
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        return future.result(timeout=timeout)
    except Exception as e:
        logger.error(f"Async operation failed: {type(e).__name__}: {e}")
        raise
    return asyncio.new_event_loop()


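The scheduler's pattern above, a dedicated event loop running forever in a daemon thread with jobs handing coroutines to it via `asyncio.run_coroutine_threadsafe`, can be reduced to a small self-contained sketch; the names below are illustrative, not the scheduler's actual helpers:

import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True, name="JobEventLoop").start()


async def do_work(x: int) -> int:
    await asyncio.sleep(0.1)
    return x * 2


def run_blocking(coro, timeout: float = 300):
    # Submit the coroutine to the background loop and block until it finishes.
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    return future.result(timeout=timeout)


print(run_blocking(do_work(21)))  # 42
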
def execute_graph(**kwargs):
    """Execute graph in the shared event loop and wait for completion."""
    # Wait for completion to ensure job doesn't exit prematurely
    run_async(_execute_graph(**kwargs))
    get_event_loop().run_until_complete(_execute_graph(**kwargs))


async def _execute_graph(**kwargs):
    args = GraphExecutionJobArgs(**kwargs)
    start_time = asyncio.get_event_loop().time()
    try:
        logger.info(f"Executing recurring job for graph #{args.graph_id}")
        graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -144,28 +89,17 @@ async def _execute_graph(**kwargs):
            inputs=args.input_data,
            graph_credentials_inputs=args.input_credentials,
        )
        elapsed = asyncio.get_event_loop().time() - start_time
        logger.info(
            f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
            f"(took {elapsed:.2f}s to create and publish)"
            f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
        )
        if elapsed > 10:
            logger.warning(
                f"Graph execution {graph_exec.id} took {elapsed:.2f}s to create/publish - "
                f"this is unusually slow and may indicate resource contention"
            )
    except Exception as e:
        elapsed = asyncio.get_event_loop().time() - start_time
        logger.error(
            f"Error executing graph {args.graph_id} after {elapsed:.2f}s: "
            f"{type(e).__name__}: {e}"
        )
        # TODO: We need to communicate this error to the user somehow.
        logger.error(f"Error executing graph {args.graph_id}: {e}")


def cleanup_expired_files():
    """Clean up expired files from cloud storage."""
    # Wait for completion
    run_async(cleanup_expired_files_async())
    get_event_loop().run_until_complete(cleanup_expired_files_async())


# Monitoring functions are now imported from monitoring module
@@ -221,7 +155,7 @@ class NotificationJobInfo(NotificationJobArgs):


class Scheduler(AppService):
    scheduler: BackgroundScheduler
    scheduler: BlockingScheduler

    def __init__(self, register_system_tasks: bool = True):
        self.register_system_tasks = register_system_tasks
@@ -234,48 +168,15 @@ class Scheduler(AppService):
    def db_pool_size(cls) -> int:
        return config.scheduler_db_pool_size

    async def health_check(self) -> str:
        # Thread-safe health check with proper initialization handling
        if not hasattr(self, "scheduler"):
            raise UnhealthyServiceError("Scheduler is still initializing")

        # Check if we're in the middle of cleanup
        if self.cleaned_up:
            return await super().health_check()

        # Normal operation - check if scheduler is running
    def health_check(self) -> str:
        if not self.scheduler.running:
            raise UnhealthyServiceError("Scheduler is not running")

        return await super().health_check()
            raise RuntimeError("Scheduler is not running")
        return super().health_check()

    def run_service(self):
        load_dotenv()

        # Initialize the event loop for async jobs
        global _event_loop
        _event_loop = asyncio.new_event_loop()

        # Use daemon thread since it should die with the main service
        global _event_loop_thread
        _event_loop_thread = threading.Thread(
            target=_event_loop.run_forever, daemon=True, name="SchedulerEventLoop"
        )
        _event_loop_thread.start()

        db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
        # Configure executors to limit concurrency without skipping jobs
        from apscheduler.executors.pool import ThreadPoolExecutor

        self.scheduler = BackgroundScheduler(
            executors={
                "default": ThreadPoolExecutor(max_workers=10), # Max 10 concurrent jobs
            },
            job_defaults={
                "coalesce": True, # Skip redundant missed jobs - just run the latest
                "max_instances": 1000, # Effectively unlimited - never drop executions
                "misfire_grace_time": None, # No time limit for missed jobs
            },
        self.scheduler = BlockingScheduler(
            jobstores={
                Jobstores.EXECUTION.value: SQLAlchemyJobStore(
                    engine=create_engine(
@@ -355,30 +256,13 @@ class Scheduler(AppService):
        )

        self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
        self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
        self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
        self.scheduler.start()

        # Keep the service running since BackgroundScheduler doesn't block
        super().run_service()

    def cleanup(self):
        super().cleanup()
        if self.scheduler:
            logger.info("⏳ Shutting down scheduler...")
            self.scheduler.shutdown(wait=True)

        global _event_loop
        if _event_loop:
            logger.info("⏳ Closing event loop...")
            _event_loop.call_soon_threadsafe(_event_loop.stop)

        global _event_loop_thread
        if _event_loop_thread:
            logger.info("⏳ Waiting for event loop thread to finish...")
            _event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)

        logger.info("Scheduler cleanup complete.")
            self.scheduler.shutdown(wait=False)

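For context, APScheduler's `job_defaults` control the behaviors the new configuration relies on: `coalesce` collapses a backlog of missed runs into one, `max_instances` caps concurrent runs of the same job, and `misfire_grace_time=None` means a late job is never discarded. A minimal hedged sketch of that configuration (the job itself is illustrative, not the platform's):

from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.schedulers.background import BackgroundScheduler


def tick():
    print("tick")


scheduler = BackgroundScheduler(
    executors={"default": ThreadPoolExecutor(max_workers=10)},
    job_defaults={
        "coalesce": True,  # run once even if several runs were missed
        "max_instances": 1000,  # effectively never skip because one is still running
        "misfire_grace_time": None,  # late jobs still run instead of being dropped
    },
)
scheduler.add_job(tick, "interval", seconds=5)
scheduler.start()  # returns immediately; the service must keep the process alive
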
    @expose
    def add_graph_execution_schedule(
@@ -391,18 +275,6 @@ class Scheduler(AppService):
        input_credentials: dict[str, CredentialsMetaInput],
        name: Optional[str] = None,
    ) -> GraphExecutionJobInfo:
        # Validate the graph before scheduling to prevent runtime failures
        # We don't need the return value, just want the validation to run
        run_async(
            execution_utils.validate_and_construct_node_execution_input(
                graph_id=graph_id,
                user_id=user_id,
                graph_inputs=input_data,
                graph_version=graph_version,
                graph_credentials_inputs=input_credentials,
            )
        )

        job_args = GraphExecutionJobArgs(
            user_id=user_id,
            graph_id=graph_id,

@@ -548,7 +548,7 @@ async def validate_graph_with_credentials(
    return node_input_errors


async def _construct_node_execution_input(
async def construct_node_execution_input(
    graph: GraphModel,
    user_id: str,
    graph_inputs: BlockInput,
@@ -615,67 +615,6 @@ async def _construct_node_execution_input(
    return nodes_input


async def validate_and_construct_node_execution_input(
    graph_id: str,
    user_id: str,
    graph_inputs: BlockInput,
    graph_version: Optional[int] = None,
    graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
    nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> tuple[GraphModel, list[tuple[str, BlockInput]]]:
    """
    Public wrapper that handles graph fetching, credential mapping, and validation+construction.
    This centralizes the logic used by both scheduler validation and actual execution.

    Args:
        graph_id: The ID of the graph to validate/construct.
        user_id: The ID of the user.
        graph_inputs: The input data for the graph execution.
        graph_version: The version of the graph to use.
        graph_credentials_inputs: Credentials inputs to use.
        nodes_input_masks: Node inputs to use.

    Returns:
        tuple[GraphModel, list[tuple[str, BlockInput]]]: Graph model and list of tuples for node execution input.

    Raises:
        NotFoundError: If the graph is not found.
        GraphValidationError: If the graph has validation issues.
        ValueError: If there are other validation errors.
    """
    if prisma.is_connected():
        gdb = graph_db
    else:
        gdb = get_database_manager_async_client()

    graph: GraphModel | None = await gdb.get_graph(
        graph_id=graph_id,
        user_id=user_id,
        version=graph_version,
        include_subgraphs=True,
    )
    if not graph:
        raise NotFoundError(f"Graph #{graph_id} not found.")

    nodes_input_masks = _merge_nodes_input_masks(
        (
            make_node_credentials_input_map(graph, graph_credentials_inputs)
            if graph_credentials_inputs
            else {}
        ),
        nodes_input_masks or {},
    )

    starting_nodes_input = await _construct_node_execution_input(
        graph=graph,
        user_id=user_id,
        graph_inputs=graph_inputs,
        nodes_input_masks=nodes_input_masks,
    )

    return graph, starting_nodes_input


def _merge_nodes_input_masks(
    overrides_map_1: dict[str, dict[str, JsonValue]],
    overrides_map_2: dict[str, dict[str, JsonValue]],

@@ -852,16 +791,33 @@ async def add_graph_execution(
        ValueError: If the graph is not found or if there are validation errors.
    """
    if prisma.is_connected():
        gdb = graph_db
        edb = execution_db
    else:
        gdb = get_database_manager_async_client()
        edb = get_database_manager_async_client()

    graph, starting_nodes_input = await validate_and_construct_node_execution_input(
    graph: GraphModel | None = await gdb.get_graph(
        graph_id=graph_id,
        user_id=user_id,
        version=graph_version,
        include_subgraphs=True,
    )
    if not graph:
        raise NotFoundError(f"Graph #{graph_id} not found.")

    nodes_input_masks = _merge_nodes_input_masks(
        (
            make_node_credentials_input_map(graph, graph_credentials_inputs)
            if graph_credentials_inputs
            else {}
        ),
        nodes_input_masks or {},
    )
    starting_nodes_input = await construct_node_execution_input(
        graph=graph,
        user_id=user_id,
        graph_inputs=inputs or {},
        graph_version=graph_version,
        graph_credentials_inputs=graph_credentials_inputs,
        nodes_input_masks=nodes_input_masks,
    )
    graph_exec = None
@@ -879,19 +835,11 @@ async def add_graph_execution(
        graph_exec_entry = graph_exec.to_graph_execution_entry()
        if nodes_input_masks:
            graph_exec_entry.nodes_input_masks = nodes_input_masks

        logger.info(
            f"Created graph execution #{graph_exec.id} for graph "
            f"#{graph_id} with {len(starting_nodes_input)} starting nodes. "
            f"Now publishing to execution queue."
        )

        await queue.publish_message(
            routing_key=GRAPH_EXECUTION_ROUTING_KEY,
            message=graph_exec_entry.model_dump_json(),
            exchange=GRAPH_EXECUTION_EXCHANGE,
        )
        logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")

        bus = get_async_execution_event_bus()
        await bus.publish(graph_exec)

@@ -1,15 +0,0 @@
from backend.app import run_processes
from backend.notifications.notifications import NotificationManager


def main():
    """
    Run the AutoGPT-server Notification Service.
    """
    run_processes(
        NotificationManager(),
    )


if __name__ == "__main__":
    main()
@@ -1,7 +1,8 @@
import asyncio
import logging
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime, timedelta, timezone
from typing import Awaitable, Callable
from typing import Callable

import aio_pika
from prisma.enums import NotificationType
@@ -27,17 +28,11 @@ from backend.data.notifications import (
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.clients import get_database_manager_async_client
from backend.util.clients import get_database_manager_client
from backend.util.logging import TruncatedLogger
from backend.util.metrics import discord_send_alert
from backend.util.retry import continuous_retry
from backend.util.service import (
    AppService,
    AppServiceClient,
    UnhealthyServiceError,
    endpoint_to_sync,
    expose,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Settings

logger = TruncatedLogger(logging.getLogger(__name__), "[NotificationManager]")
@@ -48,6 +43,8 @@ NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
EXCHANGES = [NOTIFICATION_EXCHANGE, DEAD_LETTER_EXCHANGE]

background_executor = ProcessPoolExecutor(max_workers=2)


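The `background_executor` reintroduced here runs the summary and batch coroutines in separate worker processes, each wrapping its coroutine with `asyncio.run`; the other side of the diff instead schedules them on the service's own loop with `asyncio.create_task`. A small hedged sketch of the executor pattern (function names are illustrative):

import asyncio
from concurrent.futures import ProcessPoolExecutor


async def build_summary(user_id: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for DB calls
    return f"summary for {user_id}"


def build_summary_blocking(user_id: str) -> str:
    # Each worker process gets its own event loop via asyncio.run().
    return asyncio.run(build_summary(user_id))


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as pool:
        print(pool.submit(build_summary_blocking, "user-123").result())
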
def create_notification_config() -> RabbitMQConfig:
    """Create RabbitMQ configuration for notifications"""
@@ -188,33 +185,24 @@ class NotificationManager(AppService):
    @property
    def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
        """Access the RabbitMQ service. Will raise if not configured."""
        if not hasattr(self, "rabbitmq_service") or not self.rabbitmq_service:
            raise UnhealthyServiceError("RabbitMQ not configured for this service")
        if not self.rabbitmq_service:
            raise RuntimeError("RabbitMQ not configured for this service")
        return self.rabbitmq_service

    @property
    def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
        """Access the RabbitMQ config. Will raise if not configured."""
        if not self.rabbitmq_config:
            raise UnhealthyServiceError("RabbitMQ not configured for this service")
            raise RuntimeError("RabbitMQ not configured for this service")
        return self.rabbitmq_config

    async def health_check(self) -> str:
        # Service is unhealthy if RabbitMQ is not ready
        if not hasattr(self, "rabbitmq_service") or not self.rabbitmq_service:
            raise UnhealthyServiceError("RabbitMQ not configured for this service")
        if not self.rabbitmq_service.is_ready:
            raise UnhealthyServiceError("RabbitMQ channel is not ready")
        return await super().health_check()

    @classmethod
    def get_port(cls) -> int:
        return settings.config.notification_service_port

    @expose
    async def queue_weekly_summary(self):
        # Use the existing event loop instead of creating a new one with asyncio.run()
        asyncio.create_task(self._queue_weekly_summary())
    def queue_weekly_summary(self):
        background_executor.submit(lambda: asyncio.run(self._queue_weekly_summary()))

    async def _queue_weekly_summary(self):
        """Process weekly summary for specified notification types"""
@@ -223,7 +211,7 @@ class NotificationManager(AppService):
            processed_count = 0
            current_time = datetime.now(tz=timezone.utc)
            start_time = current_time - timedelta(days=7)
            users = await get_database_manager_async_client().get_active_user_ids_in_timerange(
            users = get_database_manager_client().get_active_user_ids_in_timerange(
                end_time=current_time.isoformat(),
                start_time=start_time.isoformat(),
            )
@@ -246,15 +234,10 @@ class NotificationManager(AppService):
            logger.exception(f"Error processing weekly summary: {e}")

    @expose
    async def process_existing_batches(
        self, notification_types: list[NotificationType]
    ):
        # Use the existing event loop instead of creating a new process
        asyncio.create_task(self._process_existing_batches(notification_types))
    def process_existing_batches(self, notification_types: list[NotificationType]):
        background_executor.submit(self._process_existing_batches, notification_types)

    async def _process_existing_batches(
        self, notification_types: list[NotificationType]
    ):
    def _process_existing_batches(self, notification_types: list[NotificationType]):
        """Process existing batches for specified notification types"""
        try:
            processed_count = 0
@@ -262,15 +245,13 @@ class NotificationManager(AppService):

            for notification_type in notification_types:
                # Get all batches for this notification type
                batches = (
                    await get_database_manager_async_client().get_all_batches_by_type(
                        notification_type
                    )
                batches = get_database_manager_client().get_all_batches_by_type(
                    notification_type
                )

                for batch in batches:
                    # Check if batch has aged out
                    oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
                    oldest_message = get_database_manager_client().get_user_notification_oldest_message_in_batch(
                        batch.user_id, notification_type
                    )

@@ -285,8 +266,10 @@ class NotificationManager(AppService):

                    # If batch has aged out, process it
                    if oldest_message.created_at + max_delay < current_time:
                        recipient_email = await get_database_manager_async_client().get_user_email_by_id(
                            batch.user_id
                        recipient_email = (
                            get_database_manager_client().get_user_email_by_id(
                                batch.user_id
                            )
                        )

                        if not recipient_email:
@@ -295,7 +278,7 @@ class NotificationManager(AppService):
                            )
                            continue

                        should_send = await self._should_email_user_based_on_preference(
                        should_send = self._should_email_user_based_on_preference(
                            batch.user_id, notification_type
                        )

@@ -304,13 +287,15 @@ class NotificationManager(AppService):
                                f"User {batch.user_id} does not want to receive {notification_type} notifications"
                            )
                            # Clear the batch
                            await get_database_manager_async_client().empty_user_notification_batch(
                            get_database_manager_client().empty_user_notification_batch(
                                batch.user_id, notification_type
                            )
                            continue

                        batch_data = await get_database_manager_async_client().get_user_notification_batch(
                            batch.user_id, notification_type
                        batch_data = (
                            get_database_manager_client().get_user_notification_batch(
                                batch.user_id, notification_type
                            )
                        )

                        if not batch_data or not batch_data.notifications:
@@ -318,7 +303,7 @@ class NotificationManager(AppService):
                                f"Batch data not found for user {batch.user_id}"
                            )
                            # Clear the batch
                            await get_database_manager_async_client().empty_user_notification_batch(
                            get_database_manager_client().empty_user_notification_batch(
                                batch.user_id, notification_type
                            )
                            continue
@@ -354,7 +339,7 @@ class NotificationManager(AppService):
                        )

                        # Clear the batch
                        await get_database_manager_async_client().empty_user_notification_batch(
                        get_database_manager_client().empty_user_notification_batch(
                            batch.user_id, notification_type
                        )

@@ -399,20 +384,18 @@ class NotificationManager(AppService):
        except Exception as e:
            logger.exception(f"Error queueing notification: {e}")

    async def _should_email_user_based_on_preference(
    def _should_email_user_based_on_preference(
        self, user_id: str, event_type: NotificationType
    ) -> bool:
        """Check if a user wants to receive a notification based on their preferences and email verification status"""
        validated_email = (
            await get_database_manager_async_client().get_user_email_verification(
                user_id
            )
        validated_email = get_database_manager_client().get_user_email_verification(
            user_id
        )
        preference = (
            await get_database_manager_async_client().get_user_notification_preference(
                user_id
            )
        ).preferences.get(event_type, True)
            get_database_manager_client()
            .get_user_notification_preference(user_id)
            .preferences.get(event_type, True)
        )
        # only if both are true, should we email this person
        return validated_email and preference

@@ -496,16 +479,18 @@ class NotificationManager(AppService):
        else:
            raise ValueError("Invalid event type or params")

    async def _should_batch(
    def _should_batch(
        self, user_id: str, event_type: NotificationType, event: NotificationEventModel
    ) -> bool:

        await get_database_manager_async_client().create_or_add_to_user_notification_batch(
        get_database_manager_client().create_or_add_to_user_notification_batch(
            user_id, event_type, event
        )

        oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
            user_id, event_type
        oldest_message = (
            get_database_manager_client().get_user_notification_oldest_message_in_batch(
                user_id, event_type
            )
        )
        if not oldest_message:
            logger.error(
@@ -534,7 +519,7 @@ class NotificationManager(AppService):
            logger.error(f"Error parsing message due to non matching schema {e}")
            return None

    async def _process_admin_message(self, message: str) -> bool:
    def _process_admin_message(self, message: str) -> bool:
        """Process a single notification, sending to an admin, returning whether to put into the failed queue"""
        try:
            event = self._parse_message(message)
@@ -548,7 +533,7 @@ class NotificationManager(AppService):
            logger.exception(f"Error processing notification for admin queue: {e}")
            return False

    async def _process_immediate(self, message: str) -> bool:
    def _process_immediate(self, message: str) -> bool:
        """Process a single notification immediately, returning whether to put into the failed queue"""
        try:
            event = self._parse_message(message)
@@ -556,16 +541,14 @@ class NotificationManager(AppService):
                return False
            logger.debug(f"Processing immediate notification: {event}")

            recipient_email = (
                await get_database_manager_async_client().get_user_email_by_id(
                    event.user_id
                )
            recipient_email = get_database_manager_client().get_user_email_by_id(
                event.user_id
            )
            if not recipient_email:
                logger.error(f"User email not found for user {event.user_id}")
                return False

            should_send = await self._should_email_user_based_on_preference(
            should_send = self._should_email_user_based_on_preference(
                event.user_id, event.type
            )
            if not should_send:
@@ -587,7 +570,7 @@ class NotificationManager(AppService):
            logger.exception(f"Error processing notification for immediate queue: {e}")
            return False

    async def _process_batch(self, message: str) -> bool:
    def _process_batch(self, message: str) -> bool:
        """Process a single notification with a batching strategy, returning whether to put into the failed queue"""
        try:
            event = self._parse_message(message)
@@ -595,16 +578,14 @@ class NotificationManager(AppService):
                return False
            logger.info(f"Processing batch notification: {event}")

            recipient_email = (
                await get_database_manager_async_client().get_user_email_by_id(
                    event.user_id
                )
            recipient_email = get_database_manager_client().get_user_email_by_id(
                event.user_id
            )
            if not recipient_email:
                logger.error(f"User email not found for user {event.user_id}")
                return False

            should_send = await self._should_email_user_based_on_preference(
            should_send = self._should_email_user_based_on_preference(
                event.user_id, event.type
            )
            if not should_send:
@@ -613,15 +594,13 @@ class NotificationManager(AppService):
                )
                return True

            should_send = await self._should_batch(event.user_id, event.type, event)
            should_send = self._should_batch(event.user_id, event.type, event)

            if not should_send:
                logger.info("Batch not old enough to send")
                return False
            batch = (
                await get_database_manager_async_client().get_user_notification_batch(
                    event.user_id, event.type
                )
            batch = get_database_manager_client().get_user_notification_batch(
                event.user_id, event.type
            )
            if not batch or not batch.notifications:
                logger.error(f"Batch not found for user {event.user_id}")
@@ -723,7 +702,7 @@ class NotificationManager(AppService):
                logger.info(
                    f"Successfully sent all {successfully_sent_count} notifications, clearing batch"
                )
                await get_database_manager_async_client().empty_user_notification_batch(
                get_database_manager_client().empty_user_notification_batch(
                    event.user_id, event.type
                )
            else:
@@ -736,7 +715,7 @@ class NotificationManager(AppService):
            logger.exception(f"Error processing notification for batch queue: {e}")
            return False

    async def _process_summary(self, message: str) -> bool:
    def _process_summary(self, message: str) -> bool:
        """Process a single notification with a summary strategy, returning whether to put into the failed queue"""
        try:
            logger.info(f"Processing summary notification: {message}")
@@ -747,15 +726,13 @@ class NotificationManager(AppService):

            logger.info(f"Processing summary notification: {model}")

            recipient_email = (
                await get_database_manager_async_client().get_user_email_by_id(
                    event.user_id
                )
            recipient_email = get_database_manager_client().get_user_email_by_id(
                event.user_id
            )
            if not recipient_email:
                logger.error(f"User email not found for user {event.user_id}")
                return False
            should_send = await self._should_email_user_based_on_preference(
            should_send = self._should_email_user_based_on_preference(
                event.user_id, event.type
            )
            if not should_send:
@@ -790,7 +767,7 @@ class NotificationManager(AppService):
    async def _consume_queue(
        self,
        queue: aio_pika.abc.AbstractQueue,
        process_func: Callable[[str], Awaitable[bool]],
        process_func: Callable[[str], bool],
        queue_name: str,
    ):
        """Continuously consume messages from a queue using async iteration"""
@@ -804,7 +781,7 @@ class NotificationManager(AppService):

            try:
                async with message.process():
                    result = await process_func(message.body.decode())
                    result = process_func(message.body.decode())
                    if not result:
                        # Message will be rejected when exiting context without exception
                        raise aio_pika.exceptions.MessageProcessError(
@@ -903,8 +880,6 @@ class NotificationManagerClient(AppServiceClient):
    def get_service_type(cls):
        return NotificationManager

    process_existing_batches = endpoint_to_sync(
        NotificationManager.process_existing_batches
    )
    queue_weekly_summary = endpoint_to_sync(NotificationManager.queue_weekly_summary)
    process_existing_batches = NotificationManager.process_existing_batches
|
||||
queue_weekly_summary = NotificationManager.queue_weekly_summary
|
||||
discord_system_alert = endpoint_to_sync(NotificationManager.discord_system_alert)
|
||||
|
||||
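Note on the client hunk above: one side binds the service methods directly, the other wraps them with endpoint_to_sync so synchronous callers can invoke async endpoints. A minimal sketch of the wrapper idea, assuming a helper that simply drives the coroutine on its own event loop; the name and signature below are illustrative, not the repository's actual endpoint_to_sync.

import asyncio
from typing import Awaitable, Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def endpoint_to_sync_sketch(coro_fn: Callable[P, Awaitable[R]]) -> Callable[P, R]:
    """Illustrative only: expose an async service endpoint to synchronous callers."""

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # Run the coroutine to completion on a fresh event loop.
        return asyncio.run(coro_fn(*args, **kwargs))

    return wrapper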
@@ -1,5 +1,6 @@
from backend.app import run_processes
from backend.executor.scheduler import Scheduler
from backend.notifications.notifications import NotificationManager


def main():

@@ -7,6 +8,7 @@ def main():
Run all the processes required for the AutoGPT-server Scheduling System.
"""
run_processes(
NotificationManager(),
Scheduler(),
)


@@ -634,7 +634,7 @@ async def get_ayrshare_sso_url(
# SocialPlatform.TELEGRAM,
# SocialPlatform.GOOGLE_MY_BUSINESS,
# SocialPlatform.PINTEREST,
SocialPlatform.TIKTOK,
# SocialPlatform.TIKTOK,
# SocialPlatform.BLUESKY,
# SocialPlatform.SNAPCHAT,
# SocialPlatform.THREADS,

@@ -40,8 +40,6 @@ from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
from backend.server.middleware.security import SecurityHeadersMiddleware
from backend.util import json
from backend.util.cloud_storage import shutdown_cloud_storage_handler
from backend.util.service import UnhealthyServiceError

settings = backend.util.settings.Settings()
logger = logging.getLogger(__name__)

@@ -77,12 +75,6 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
with launch_darkly_context():
yield

try:
await shutdown_cloud_storage_handler()
except Exception as e:
logger.warning(f"Error shutting down cloud storage handler: {e}")

await backend.data.db.disconnect()


@@ -233,7 +225,7 @@ app.mount("/external-api", external_app)
@app.get(path="/health", tags=["health"], dependencies=[])
async def health():
if not backend.data.db.is_connected():
raise UnhealthyServiceError("Database is not connected")
raise RuntimeError("Database is not connected")
return {"status": "healthy"}
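The /health hunk above swaps a bare RuntimeError for UnhealthyServiceError (defined further down in this diff). A minimal sketch of why a dedicated exception helps, assuming an exception handler elsewhere maps it to a non-200 response; the handler and the database stub below are illustrative, not part of the diff.

import fastapi
from fastapi.responses import JSONResponse

app = fastapi.FastAPI()


class UnhealthyServiceError(ValueError):
    """Simplified stand-in for the exception defined in backend.util.service."""


def database_is_connected() -> bool:
    return True  # placeholder for backend.data.db.is_connected()


@app.exception_handler(UnhealthyServiceError)
async def unhealthy_handler(_: fastapi.Request, exc: UnhealthyServiceError) -> JSONResponse:
    # Map readiness failures to 503 so orchestrators stop routing traffic here.
    return JSONResponse(status_code=503, content={"detail": str(exc)})


@app.get("/health")
async def health():
    if not database_is_connected():
        raise UnhealthyServiceError("Database is not connected")
    return {"status": "healthy"}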
@@ -33,30 +33,30 @@ async def check_media_exists(user_id: str, filename: str) -> str | None:
if not settings.config.media_gcs_bucket_name:
raise MissingConfigError("GCS media bucket is not configured")

async with async_storage.Storage() as async_client:
bucket_name = settings.config.media_gcs_bucket_name
async_client = async_storage.Storage()
bucket_name = settings.config.media_gcs_bucket_name

# Check images
image_path = f"users/{user_id}/images/{filename}"
try:
await async_client.download_metadata(bucket_name, image_path)
# If we get here, the file exists - construct public URL
return f"https://storage.googleapis.com/{bucket_name}/{image_path}"
except Exception:
# File doesn't exist, continue to check videos
pass
# Check images
image_path = f"users/{user_id}/images/{filename}"
try:
await async_client.download_metadata(bucket_name, image_path)
# If we get here, the file exists - construct public URL
return f"https://storage.googleapis.com/{bucket_name}/{image_path}"
except Exception:
# File doesn't exist, continue to check videos
pass

# Check videos
video_path = f"users/{user_id}/videos/{filename}"
try:
await async_client.download_metadata(bucket_name, video_path)
# If we get here, the file exists - construct public URL
return f"https://storage.googleapis.com/{bucket_name}/{video_path}"
except Exception:
# File doesn't exist
pass
# Check videos
video_path = f"users/{user_id}/videos/{filename}"
try:
await async_client.download_metadata(bucket_name, video_path)
# If we get here, the file exists - construct public URL
return f"https://storage.googleapis.com/{bucket_name}/{video_path}"
except Exception:
# File doesn't exist
pass

return None
return None


async def upload_media(

@@ -177,24 +177,22 @@ async def upload_media(
storage_path = f"users/{user_id}/{media_type}/{unique_filename}"

try:
async with async_storage.Storage() as async_client:
bucket_name = settings.config.media_gcs_bucket_name
async_client = async_storage.Storage()
bucket_name = settings.config.media_gcs_bucket_name

file_bytes = await file.read()
await scan_content_safe(file_bytes, filename=unique_filename)
file_bytes = await file.read()
await scan_content_safe(file_bytes, filename=unique_filename)

# Upload using pure async client
await async_client.upload(
bucket_name, storage_path, file_bytes, content_type=content_type
)
# Upload using pure async client
await async_client.upload(
bucket_name, storage_path, file_bytes, content_type=content_type
)

# Construct public URL
public_url = (
f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
)
# Construct public URL
public_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"

logger.info(f"Successfully uploaded file to: {storage_path}")
return public_url
logger.info(f"Successfully uploaded file to: {storage_path}")
return public_url

except Exception as e:
logger.error(f"GCS storage error: {str(e)}")
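The media hunks above move between constructing async_storage.Storage() directly and using it as an async context manager. A minimal sketch of the context-managed existence check, assuming the gcloud-aio-storage client is importable as shown and exposes download_metadata as used in the diff; the bucket and object path are illustrative.

import asyncio

from gcloud.aio import storage as async_storage


async def object_exists(bucket_name: str, object_path: str) -> bool:
    # Leaving the context manager closes the underlying HTTP session.
    async with async_storage.Storage() as client:
        try:
            await client.download_metadata(bucket_name, object_path)
            return True
        except Exception:
            return False


if __name__ == "__main__":
    print(asyncio.run(object_exists("example-bucket", "users/u1/images/avatar.png")))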
@@ -26,10 +26,6 @@ def mock_storage_client(mocker):
mock_client = AsyncMock()
mock_client.upload = AsyncMock()

# Mock context manager methods
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)

# Mock the constructor to return our mock client
mocker.patch(
"backend.server.v2.store.media.async_storage.Storage", return_value=mock_client

@@ -1,12 +1,13 @@
from pathlib import Path

from prisma.models import User

from backend.blocks.basic import StoreValueBlock
from backend.blocks.block import BlockInstallationBlock
from backend.blocks.http import SendWebRequestBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.blocks.text import ExtractTextInformationBlock, FillTextTemplateBlock
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import User
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution


@@ -1,8 +1,9 @@
from prisma.models import User

from backend.blocks.llm import AIStructuredResponseGeneratorBlock
from backend.blocks.reddit import GetRedditPostsBlock, PostRedditCommentBlock
from backend.blocks.text import FillTextTemplateBlock, MatchTextPatternBlock
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import User
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution


@@ -1,9 +1,10 @@
from prisma.models import User

from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.text import FillTextTemplateBlock
from backend.data import graph
from backend.data.graph import create_graph
from backend.data.model import User
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution


@@ -46,20 +46,6 @@ class CloudStorageHandler:
self._async_gcs_client = async_gcs_storage.Storage()
return self._async_gcs_client

async def close(self):
"""Close all client connections properly."""
if self._async_gcs_client is not None:
await self._async_gcs_client.close()
self._async_gcs_client = None

async def __aenter__(self):
"""Async context manager entry."""
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.close()

def _get_sync_gcs_client(self):
"""Lazy initialization of sync GCS client (only for signed URLs)."""
if self._sync_gcs_client is None:

@@ -521,17 +507,6 @@ async def get_cloud_storage_handler() -> CloudStorageHandler:
return _cloud_storage_handler


async def shutdown_cloud_storage_handler():
"""Properly shutdown the global cloud storage handler."""
global _cloud_storage_handler

if _cloud_storage_handler is not None:
async with _handler_lock:
if _cloud_storage_handler is not None:
await _cloud_storage_handler.close()
_cloud_storage_handler = None


async def cleanup_expired_files_async() -> int:
"""
Clean up expired files from cloud storage.
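The CloudStorageHandler hunks add (or drop) an explicit close/async-context protocol plus a lock-guarded shutdown of the module-level handler. A minimal sketch of that double-checked shutdown pattern, with names simplified from the diff.

import asyncio


class _Handler:
    async def close(self) -> None:
        # Release any underlying client sessions here.
        pass


_handler: _Handler | None = None
_handler_lock = asyncio.Lock()


async def shutdown_handler() -> None:
    """Close and clear the global handler exactly once, even under concurrency."""
    global _handler
    if _handler is not None:
        async with _handler_lock:
            if _handler is not None:  # re-check after acquiring the lock
                await _handler.close()
                _handler = None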
@@ -7,16 +7,14 @@ from sentry_sdk.integrations.logging import LoggingIntegration

from backend.util.settings import Settings

settings = Settings()


def sentry_init():
sentry_dsn = settings.secrets.sentry_dsn
sentry_dsn = Settings().secrets.sentry_dsn
sentry_sdk.init(
dsn=sentry_dsn,
traces_sample_rate=1.0,
profiles_sample_rate=1.0,
environment=f"app:{settings.config.app_env.value}-behave:{settings.config.behave_as.value}",
environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
_experiments={"enable_logs": True},
integrations=[
LoggingIntegration(sentry_logs_level=logging.INFO),

@@ -35,7 +33,9 @@ def sentry_capture_error(error: Exception):
async def discord_send_alert(content: str):
from backend.blocks.discord import SendDiscordMessageBlock
from backend.data.model import APIKeyCredentials, CredentialsMetaInput, ProviderName
from backend.util.settings import Settings

settings = Settings()
creds = APIKeyCredentials(
provider="discord",
api_key=SecretStr(settings.secrets.discord_bot_token),

@@ -45,34 +45,6 @@ api_comm_retry = config.pyro_client_comm_retry
api_comm_timeout = config.pyro_client_comm_timeout
api_call_timeout = config.rpc_client_call_timeout


def _validate_no_prisma_objects(obj: Any, path: str = "result") -> None:
"""
Recursively validate that no Prisma objects are being returned from service methods.
This enforces proper separation of layers - only application models should cross service boundaries.
"""
if obj is None:
return

# Check if it's a Prisma model object
if hasattr(obj, "__class__") and hasattr(obj.__class__, "__module__"):
module_name = obj.__class__.__module__
if module_name and "prisma.models" in module_name:
raise ValueError(
f"Prisma object {obj.__class__.__name__} found in {path}. "
"Service methods must return application models, not Prisma objects. "
f"Use {obj.__class__.__name__}.from_db() to convert to application model."
)

# Recursively check collections
if isinstance(obj, (list, tuple)):
for i, item in enumerate(obj):
_validate_no_prisma_objects(item, f"{path}[{i}]")
elif isinstance(obj, dict):
for key, value in obj.items():
_validate_no_prisma_objects(value, f"{path}['{key}']")
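_validate_no_prisma_objects (in the hunk above) walks nested results and rejects anything whose class comes from prisma.models. A small usage sketch, assuming the branch that keeps the helper in backend.util.service; the stand-in class below only imitates a generated Prisma model.

from backend.util.service import _validate_no_prisma_objects


class FakePrismaUser:
    """Stand-in for a generated Prisma model (illustrative only)."""


# Make the stand-in look like it was generated under prisma.models.
FakePrismaUser.__module__ = "prisma.models"

try:
    _validate_no_prisma_objects({"users": [FakePrismaUser()]})
except ValueError as err:
    print(err)  # names result['users'][0] and suggests converting via .from_db()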
P = ParamSpec("P")
R = TypeVar("R")
EXPOSED_FLAG = "__exposed__"

@@ -125,36 +97,6 @@ class RemoteCallError(BaseModel):
args: Optional[Tuple[Any, ...]] = None


class UnhealthyServiceError(ValueError):
def __init__(
self, message: str = "Service is unhealthy or not ready", log: bool = True
):
msg = f"[{get_service_name()}] - {message}"
super().__init__(msg)
self.message = msg
if log:
logger.error(self.message)

def __str__(self):
return self.message


class HTTPClientError(Exception):
"""Exception for HTTP client errors (4xx status codes) that should not be retried."""

def __init__(self, status_code: int, message: str):
self.status_code = status_code
super().__init__(f"HTTP {status_code}: {message}")


class HTTPServerError(Exception):
"""Exception for HTTP server errors (5xx status codes) that can be retried."""

def __init__(self, status_code: int, message: str):
self.status_code = status_code
super().__init__(f"HTTP {status_code}: {message}")


EXCEPTION_MAPPING = {
e.__name__: e
for e in [

@@ -162,9 +104,6 @@ EXCEPTION_MAPPING = {
RuntimeError,
TimeoutError,
ConnectionError,
UnhealthyServiceError,
HTTPClientError,
HTTPServerError,
*[
ErrorType
for _, ErrorType in inspect.getmembers(exceptions)
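EXCEPTION_MAPPING lets the client rebuild a typed exception from the name/args pair carried by RemoteCallError. A minimal sketch of that round trip, using a hand-built mapping instead of the module's full one.

from typing import Any

EXCEPTION_MAPPING_SKETCH: dict[str, type[Exception]] = {
    e.__name__: e for e in [ValueError, RuntimeError, TimeoutError, ConnectionError]
}


def rebuild_exception(payload: dict[str, Any]) -> Exception:
    """Turn a serialized {'type': ..., 'args': [...]} payload back into an exception."""
    exc_class = EXCEPTION_MAPPING_SKETCH.get(payload.get("type", ""), Exception)
    return exc_class(*(payload.get("args") or ["remote call failed"]))


rebuilt = rebuild_exception({"type": "ValueError", "args": ["Invalid parameter value"]})
assert isinstance(rebuilt, ValueError)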
@@ -237,21 +176,17 @@ class AppService(BaseAppService, ABC):
if asyncio.iscoroutinefunction(f):

async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
result = await f(
return await f(
**{name: getattr(body, name) for name in type(body).model_fields}
)
_validate_no_prisma_objects(result, f"{func.__name__} result")
return result

return async_endpoint
else:

def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
result = f(
return f(
**{name: getattr(body, name) for name in type(body).model_fields}
)
_validate_no_prisma_objects(result, f"{func.__name__} result")
return result

return sync_endpoint

@@ -272,7 +207,7 @@ class AppService(BaseAppService, ABC):
)
self.shared_event_loop.run_until_complete(server.serve())

async def health_check(self) -> str:
def health_check(self) -> str:
"""
A method to check the health of the process.
"""

@@ -363,7 +298,6 @@ def get_service_client(
AttributeError, # Missing attributes
asyncio.CancelledError, # Task was cancelled
concurrent.futures.CancelledError, # Future was cancelled
HTTPClientError, # HTTP 4xx client errors - don't retry
),
)(fn)

@@ -441,31 +375,11 @@ def get_service_client(
self._connection_failure_count = 0
return response.json()
except httpx.HTTPStatusError as e:
status_code = e.response.status_code

# Try to parse the error response as RemoteCallError for mapped exceptions
error_response = None
try:
error_response = RemoteCallError.model_validate(e.response.json())
except Exception:
pass

# If we successfully parsed a mapped exception type, re-raise it
if error_response and error_response.type in EXCEPTION_MAPPING:
exception_class = EXCEPTION_MAPPING[error_response.type]
args = error_response.args or [str(e)]
raise exception_class(*args)

# Otherwise categorize by HTTP status code
if 400 <= status_code < 500:
# Client errors (4xx) - wrap to prevent retries
raise HTTPClientError(status_code, str(e))
elif 500 <= status_code < 600:
# Server errors (5xx) - wrap but allow retries
raise HTTPServerError(status_code, str(e))
else:
# Other status codes (1xx, 2xx, 3xx) - re-raise original error
raise e
error = RemoteCallError.model_validate(e.response.json())
# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
raise EXCEPTION_MAPPING.get(error.type, Exception)(
*(error.args or [str(e)])
)
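One side of the hunk above first re-raises any mapped exception found in the error body, then sorts the remaining HTTP failures into non-retryable 4xx and retryable 5xx wrappers. A condensed sketch of that categorization step; the two wrapper classes mirror the ones defined earlier in this diff and are redefined here only so the snippet stands alone.

class HTTPClientError(Exception):
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(f"HTTP {status_code}: {message}")


class HTTPServerError(Exception):
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(f"HTTP {status_code}: {message}")


def categorize_http_error(status_code: int, message: str) -> Exception:
    """4xx: do not retry; 5xx: safe to retry; anything else stays generic."""
    if 400 <= status_code < 500:
        return HTTPClientError(status_code, message)
    if 500 <= status_code < 600:
        return HTTPServerError(status_code, message)
    return Exception(f"HTTP {status_code}: {message}")


assert isinstance(categorize_http_error(404, "Not Found"), HTTPClientError)
assert isinstance(categorize_http_error(503, "Unavailable"), HTTPServerError)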
@_maybe_retry
def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any:

@@ -492,43 +406,11 @@ def get_service_client(
raise

async def aclose(self) -> None:
if hasattr(self, "sync_client"):
self.sync_client.close()
if hasattr(self, "async_client"):
await self.async_client.aclose()
self.sync_client.close()
await self.async_client.aclose()

def close(self) -> None:
if hasattr(self, "sync_client"):
self.sync_client.close()
# Note: Cannot close async client synchronously

def __del__(self):
"""Cleanup HTTP clients on garbage collection to prevent resource leaks."""
try:
if hasattr(self, "sync_client"):
self.sync_client.close()
if hasattr(self, "async_client"):
# Note: Can't await in __del__, so we just close sync
# The async client will be cleaned up by garbage collection
import warnings

warnings.warn(
"DynamicClient async client not explicitly closed. "
"Call aclose() before destroying the client.",
ResourceWarning,
stacklevel=2,
)
except Exception:
# Silently ignore cleanup errors in __del__
pass

async def __aenter__(self):
"""Async context manager entry."""
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.aclose()
self.sync_client.close()

def _get_params(
self, signature: inspect.Signature, *args: Any, **kwargs: Any
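The lifecycle hunk above adds (or removes) guarded close/aclose methods, a __del__ warning, and async context-manager support on the dynamic client. A usage sketch under the assumption that get_service_client returns an object with those hooks, as on one side of this diff; the endpoint call and its signature are hypothetical.

from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import get_service_client


async def notify_admins(message: str) -> None:
    # Hypothetical call; the exact endpoint signature may differ on your branch.
    async with get_service_client(NotificationManagerClient) as client:
        client.discord_system_alert(message)
    # Leaving the block invokes aclose(), releasing both HTTP clients.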
@@ -8,8 +8,6 @@ import pytest
from backend.util.service import (
AppService,
AppServiceClient,
HTTPClientError,
HTTPServerError,
endpoint_to_async,
expose,
get_service_client,

@@ -368,125 +366,3 @@ def test_service_no_retry_when_disabled(server):
# This should fail immediately without retry
with pytest.raises(RuntimeError, match="Intended error for testing"):
client.always_failing_add(5, 3)


class TestHTTPErrorRetryBehavior:
"""Test that HTTP client errors (4xx) are not retried but server errors (5xx) can be."""

# Note: These tests access private methods for testing internal behavior
# Type ignore comments are used to suppress warnings about accessing private methods

def test_http_client_error_not_retried(self):
"""Test that 4xx errors are wrapped as HTTPClientError and not retried."""
# Create a mock response with 404 status
mock_response = Mock()
mock_response.status_code = 404
mock_response.json.return_value = {"message": "Not found"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"404 Not Found", request=Mock(), response=mock_response
)

# Create client
client = get_service_client(ServiceTestClient)
dynamic_client = client

# Test the _handle_call_method_response directly
with pytest.raises(HTTPClientError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
response=mock_response, method_name="test_method"
)

assert exc_info.value.status_code == 404
assert "404" in str(exc_info.value)

def test_http_server_error_can_be_retried(self):
"""Test that 5xx errors are wrapped as HTTPServerError and can be retried."""
# Create a mock response with 500 status
mock_response = Mock()
mock_response.status_code = 500
mock_response.json.return_value = {"message": "Internal server error"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"500 Internal Server Error", request=Mock(), response=mock_response
)

# Create client
client = get_service_client(ServiceTestClient)
dynamic_client = client

# Test the _handle_call_method_response directly
with pytest.raises(HTTPServerError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
response=mock_response, method_name="test_method"
)

assert exc_info.value.status_code == 500
assert "500" in str(exc_info.value)

def test_mapped_exception_preserves_original_type(self):
"""Test that mapped exceptions preserve their original type regardless of HTTP status."""
# Create a mock response with ValueError in the remote call error
mock_response = Mock()
mock_response.status_code = 400
mock_response.json.return_value = {
"type": "ValueError",
"args": ["Invalid parameter value"],
}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"400 Bad Request", request=Mock(), response=mock_response
)

# Create client
client = get_service_client(ServiceTestClient)
dynamic_client = client

# Test the _handle_call_method_response directly
with pytest.raises(ValueError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
response=mock_response, method_name="test_method"
)

assert "Invalid parameter value" in str(exc_info.value)

def test_client_error_status_codes_coverage(self):
"""Test that various 4xx status codes are all wrapped as HTTPClientError."""
client_error_codes = [400, 401, 403, 404, 405, 409, 422, 429]

for status_code in client_error_codes:
mock_response = Mock()
mock_response.status_code = status_code
mock_response.json.return_value = {"message": f"Error {status_code}"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
f"{status_code} Error", request=Mock(), response=mock_response
)

client = get_service_client(ServiceTestClient)
dynamic_client = client

with pytest.raises(HTTPClientError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore
response=mock_response, method_name="test_method"
)

assert exc_info.value.status_code == status_code

def test_server_error_status_codes_coverage(self):
"""Test that various 5xx status codes are all wrapped as HTTPServerError."""
server_error_codes = [500, 501, 502, 503, 504, 505]

for status_code in server_error_codes:
mock_response = Mock()
mock_response.status_code = status_code
mock_response.json.return_value = {"message": f"Error {status_code}"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
f"{status_code} Error", request=Mock(), response=mock_response
)

client = get_service_client(ServiceTestClient)
dynamic_client = client

with pytest.raises(HTTPServerError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore
response=mock_response, method_name="test_method"
)

assert exc_info.value.status_code == status_code
autogpt_platform/backend/poetry.lock (generated, 21 lines changed)
@@ -1079,25 +1079,6 @@ files = [
dnspython = ">=2.0.0"
idna = ">=2.0.0"

[[package]]
name = "exa-py"
version = "1.14.20"
description = "Python SDK for Exa API."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "exa_py-1.14.20-py3-none-any.whl", hash = "sha256:e0ed9d99c3c494a0e6903e11a0f6fb773b3b23d0cd802380cf58efc97d9d332d"},
{file = "exa_py-1.14.20.tar.gz", hash = "sha256:423789a0635b7a4ecd5f56d6b4a0dfb01126fa45ce1e04106c0bb96b7d551ebf"},
]

[package.dependencies]
httpx = ">=0.28.1"
openai = ">=1.48"
pydantic = ">=2.10.6"
requests = ">=2.32.3"
typing-extensions = ">=4.12.2"

[[package]]
name = "exceptiongroup"
version = "1.3.0"

@@ -6737,4 +6718,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "05e2b99bd6dc5a74a89df0e1e504853b66bd519062159b0de5fbedf6a1f4d986"
content-hash = "225ddae645d22cc57f46330e735c069fb52e708123aa642e74adbf077dda0796"

@@ -75,7 +75,6 @@ setuptools = "^80.9.0"
gcloud-aio-storage = "^9.5.0"
pandas = "^2.3.1"
firecrawl-py = "^2.16.3"
exa-py = "^1.14.20"

[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"

@@ -102,7 +101,6 @@ rest = "backend.rest:main"
db = "backend.db:main"
ws = "backend.ws:main"
scheduler = "backend.scheduler:main"
notification = "backend.notification:main"
executor = "backend.exec:main"
cli = "backend.cli:main"
format = "linter:format"
@@ -530,61 +530,6 @@ class TestDataCreator:
submissions = []
approved_submissions = []

# Create a special test submission for test123@gmail.com
test_user = next(
(user for user in self.users if user["email"] == "test123@gmail.com"), None
)
if test_user:
# Special test data for consistent testing
test_submission_data = {
"user_id": test_user["id"],
"agent_id": self.agent_graphs[0]["id"], # Use first available graph
"agent_version": 1,
"slug": "test-agent-submission",
"name": "Test Agent Submission",
"sub_heading": "A test agent for frontend testing",
"video_url": "https://www.youtube.com/watch?v=test123",
"image_urls": [
"https://picsum.photos/200/300",
"https://picsum.photos/200/301",
"https://picsum.photos/200/302",
],
"description": "This is a test agent submission specifically created for frontend testing purposes.",
"categories": ["test", "demo", "frontend"],
"changes_summary": "Initial test submission",
}

try:
test_submission = await create_store_submission(**test_submission_data)
submissions.append(test_submission.model_dump())
print("✅ Created special test store submission for test123@gmail.com")

# Auto-approve the test submission
if test_submission.store_listing_version_id:
approved_submission = await review_store_submission(
store_listing_version_id=test_submission.store_listing_version_id,
is_approved=True,
external_comments="Test submission approved",
internal_comments="Auto-approved test submission",
reviewer_id=test_user["id"],
)
approved_submissions.append(approved_submission.model_dump())
print("✅ Approved test store submission")

# Mark test submission as featured
await prisma.storelistingversion.update(
where={"id": test_submission.store_listing_version_id},
data={"isFeatured": True},
)
print("🌟 Marked test agent as FEATURED")

except Exception as e:
print(f"Error creating test store submission: {e}")
import traceback

traceback.print_exc()

# Create regular submissions for all users
for user in self.users:
# Get available graphs for this specific user
user_graphs = [
@@ -73,6 +73,8 @@ services:
condition: service_completed_successfully
rabbitmq:
condition: service_healthy
# scheduler_server:
#   condition: service_healthy
environment:
- SUPABASE_URL=http://kong:8000
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long

@@ -90,7 +92,7 @@ services:
- PYRO_HOST=0.0.0.0
- SCHEDULER_HOST=scheduler_server
- EXECUTIONMANAGER_HOST=executor
- NOTIFICATIONMANAGER_HOST=notification_server
- NOTIFICATIONMANAGER_HOST=rest_server
- CLAMAV_SERVICE_HOST=clamav
- NEXT_PUBLIC_FRONTEND_BASE_URL=http://localhost:3000
- BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]

@@ -98,6 +100,7 @@ services:
- UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio= # DO NOT USE IN PRODUCTION!!
ports:
- "8006:8006"
- "8007:8007"
networks:
- app-network

@@ -140,7 +143,7 @@ services:
- ENABLE_AUTH=true
- PYRO_HOST=0.0.0.0
- AGENTSERVER_HOST=rest_server
- NOTIFICATIONMANAGER_HOST=notification_server
- NOTIFICATIONMANAGER_HOST=rest_server
- CLAMAV_SERVICE_HOST=clamav
- ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw= # DO NOT USE IN PRODUCTION!!
ports:

@@ -164,6 +167,8 @@ services:
condition: service_healthy
redis:
condition: service_healthy
# rabbitmq:
#   condition: service_healthy
migrate:
condition: service_completed_successfully
database_manager:

@@ -252,7 +257,7 @@ services:
# retries: 5
environment:
- DATABASEMANAGER_HOST=database_manager
- NOTIFICATIONMANAGER_HOST=notification_server
- NOTIFICATIONMANAGER_HOST=rest_server
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform

@@ -272,44 +277,6 @@ services:
networks:
- app-network

notification_server:
build:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.notification"]
develop:
watch:
- path: ./
target: autogpt_platform/backend/
action: rebuild
depends_on:
db:
condition: service_healthy
rabbitmq:
condition: service_healthy
migrate:
condition: service_completed_successfully
database_manager:
condition: service_started
environment:
- DATABASEMANAGER_HOST=database_manager
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=password
- RABBITMQ_HOST=rabbitmq
- RABBITMQ_PORT=5672
- RABBITMQ_DEFAULT_USER=rabbitmq_user_default
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
- ENABLE_AUTH=true
- PYRO_HOST=0.0.0.0
- BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]

ports:
- "8007:8007"
networks:
- app-network

# frontend:
#   build:
#     context: ../

@@ -70,12 +70,6 @@ services:
file: ./docker-compose.platform.yml
service: scheduler_server

notification_server:
<<: *agpt-services
extends:
file: ./docker-compose.platform.yml
service: notification_server

clamav:
<<: *agpt-services
image: clamav/clamav-debian:latest
@@ -29,7 +29,6 @@ export const useMainMarketplacePage = () => {
} = useGetV2ListStoreAgents(
{
sorted_by: "runs",
page_size: 1000,
},
{
query: {

@@ -26,7 +26,7 @@ export const AgentTable: React.FC<AgentTableProps> = ({
onDeleteSubmission,
}) => {
return (
<div className="w-full" data-testid="agent-table">
<div className="w-full">
{/* Table header - Hide on mobile */}
<div className="hidden flex-col md:flex">
<div className="border-t border-neutral-300 dark:border-neutral-700" />

@@ -64,11 +64,7 @@ export const AgentTableRow = ({
});

return (
<div
data-testid="agent-table-row"
data-agent-name={agentName}
className="hidden items-center border-b border-neutral-300 px-4 py-4 hover:bg-neutral-50 dark:border-neutral-700 dark:hover:bg-neutral-800 md:flex"
>
<div className="hidden items-center border-b border-neutral-300 px-4 py-4 hover:bg-neutral-50 dark:border-neutral-700 dark:hover:bg-neutral-800 md:flex">
<div className="grid w-full grid-cols-[minmax(400px,1fr),180px,140px,100px,100px,40px] items-center gap-4">
{/* Agent info column */}
<div className="flex items-center gap-4">

@@ -135,7 +131,7 @@ export const AgentTableRow = ({
{/* Actions - Three dots menu */}
<div className="flex justify-end">
<DropdownMenu.Root>
<DropdownMenu.Trigger data-testid="agent-table-row-actions">
<DropdownMenu.Trigger>
<DotsThreeVerticalIcon className="h-5 w-5 text-neutral-800" />
</DropdownMenu.Trigger>
<DropdownMenu.Content className="z-10 rounded-xl border bg-white p-1 shadow-md dark:bg-gray-800">

@@ -48,11 +48,7 @@ export const MainDashboardPage = () => {
targetState={publishState}
onStateChange={onPublishStateChange}
trigger={
<Button
data-testid="submit-agent-button"
size="small"
onClick={onOpenSubmitModal}
>
<Button size="small" onClick={onOpenSubmitModal}>
Submit agent
</Button>
}

@@ -46,7 +46,6 @@ export function AgentReviewStep({
<Text
variant="lead"
className="line-clamp-1 text-ellipsis text-center font-semibold"
data-testid="view-agent-name"
>
{agentName}
</Text>
@@ -1,108 +0,0 @@
|
||||
import { LoginPage } from "./pages/login.page";
|
||||
import test, { expect } from "@playwright/test";
|
||||
import { TEST_AGENT_DATA, TEST_CREDENTIALS } from "./credentials";
|
||||
import { getSelectors } from "./utils/selectors";
|
||||
import { hasUrl } from "./utils/assertion";
|
||||
|
||||
test.describe("Agent Dashboard", () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
const loginPage = new LoginPage(page);
|
||||
await page.goto("/login");
|
||||
await loginPage.login(TEST_CREDENTIALS.email, TEST_CREDENTIALS.password);
|
||||
await hasUrl(page, "/marketplace");
|
||||
});
|
||||
|
||||
test("dashboard page loads successfully", async ({ page }) => {
|
||||
const { getText } = getSelectors(page);
|
||||
await page.goto("/profile/dashboard");
|
||||
|
||||
await expect(getText("Agent dashboard")).toBeVisible();
|
||||
await expect(getText("Submit a New Agent")).toBeVisible();
|
||||
await expect(getText("Your uploaded agents")).toBeVisible();
|
||||
});
|
||||
|
||||
test("submit agent button works correctly", async ({ page }) => {
|
||||
const { getId, getText } = getSelectors(page);
|
||||
|
||||
await page.goto("/profile/dashboard");
|
||||
const submitAgentButton = getId("submit-agent-button");
|
||||
await expect(submitAgentButton).toBeVisible();
|
||||
await submitAgentButton.click();
|
||||
|
||||
await expect(getText("Publish Agent")).toBeVisible();
|
||||
await expect(
|
||||
getText("Select your project that you'd like to publish"),
|
||||
).toBeVisible();
|
||||
|
||||
await page.locator('button[aria-label="Close"]').click();
|
||||
await expect(getText("Publish Agent")).not.toBeVisible();
|
||||
});
|
||||
|
||||
test("agent table displays data correctly", async ({ page }) => {
|
||||
const { getText } = getSelectors(page);
|
||||
await page.goto("/profile/dashboard");
|
||||
|
||||
await expect(getText("Agent info")).toBeVisible();
|
||||
await expect(getText("Date submitted")).toBeVisible();
|
||||
|
||||
await expect(getText(TEST_AGENT_DATA.name).first()).toBeVisible();
|
||||
await expect(getText(TEST_AGENT_DATA.description).first()).toBeVisible();
|
||||
});
|
||||
|
||||
test("agent table actions work correctly", async ({ page }) => {
|
||||
await page.goto("/profile/dashboard");
|
||||
|
||||
const agentTable = page.getByTestId("agent-table");
|
||||
await expect(agentTable).toBeVisible();
|
||||
|
||||
const rows = agentTable.getByTestId("agent-table-row");
|
||||
|
||||
const testRow = rows.filter({ hasText: TEST_AGENT_DATA.name }).first();
|
||||
await testRow.scrollIntoViewIfNeeded();
|
||||
|
||||
const actionsButton = testRow.getByTestId("agent-table-row-actions");
|
||||
await actionsButton.waitFor({ state: "visible", timeout: 10000 });
|
||||
await actionsButton.scrollIntoViewIfNeeded();
|
||||
await actionsButton.click();
|
||||
|
||||
// View button testing
|
||||
const viewButton = page.getByRole("menuitem", { name: "View" });
|
||||
await expect(viewButton).toBeVisible();
|
||||
await viewButton.click();
|
||||
|
||||
const modal = page.getByTestId("publish-agent-modal");
|
||||
await expect(modal).toBeVisible();
|
||||
const viewAgentName = page.getByTestId("view-agent-name");
|
||||
await expect(viewAgentName).toBeVisible();
|
||||
await expect(viewAgentName).toHaveText(TEST_AGENT_DATA.name);
|
||||
|
||||
await page.getByRole("button", { name: "Done" }).click();
|
||||
await expect(modal).not.toBeVisible();
|
||||
|
||||
// Delete button testing
|
||||
// Delete button testing — delete the first agent in the list
|
||||
const beforeCount = await rows.count();
|
||||
|
||||
if (beforeCount === 0) {
|
||||
console.log("No agents available; skipping delete flow.");
|
||||
return;
|
||||
}
|
||||
|
||||
const firstRow = rows.first();
|
||||
await firstRow.scrollIntoViewIfNeeded();
|
||||
|
||||
const delActionsButton = firstRow.getByTestId("agent-table-row-actions");
|
||||
await delActionsButton.waitFor({ state: "visible", timeout: 10000 });
|
||||
await delActionsButton.scrollIntoViewIfNeeded();
|
||||
await delActionsButton.click();
|
||||
|
||||
const deleteButton = page.getByRole("menuitem", { name: "Delete" });
|
||||
await expect(deleteButton).toBeVisible();
|
||||
await deleteButton.click();
|
||||
|
||||
// Wait for row count to drop by 1
|
||||
await expect
|
||||
.poll(async () => await rows.count(), { timeout: 15000 })
|
||||
.toBe(beforeCount - 1);
|
||||
});
|
||||
});
|
||||
@@ -6,19 +6,3 @@ export const TEST_CREDENTIALS = {
|
||||
|
||||
// Dummy constant to help developers identify agents that don't need input
|
||||
export const DummyInput = "DummyInput";
|
||||
|
||||
// This will be used for testing agent submission for test123@gmail.com
|
||||
export const TEST_AGENT_DATA = {
|
||||
name: "Test Agent Submission",
|
||||
description:
|
||||
"This is a test agent submission specifically created for frontend testing purposes.",
|
||||
image_urls: [
|
||||
"https://picsum.photos/200/300",
|
||||
"https://picsum.photos/200/301",
|
||||
"https://picsum.photos/200/302",
|
||||
],
|
||||
video_url: "https://www.youtube.com/watch?v=test123",
|
||||
sub_heading: "A test agent for frontend testing",
|
||||
categories: ["test", "demo", "frontend"],
|
||||
changes_summary: "Initial test submission",
|
||||
} as const;
|
||||
|
||||
@@ -93,6 +93,8 @@ test.describe("Marketplace Agent Page - Basic Functionality", () => {
|
||||
await firstStoreCard.click();
|
||||
await page.waitForURL("**/marketplace/agent/**");
|
||||
|
||||
const agentName = await getId("agent-title").textContent();
|
||||
|
||||
const addToLibraryButton = getId("agent-add-library-button");
|
||||
await isVisible(addToLibraryButton);
|
||||
await addToLibraryButton.click();
|
||||
@@ -102,8 +104,7 @@ test.describe("Marketplace Agent Page - Basic Functionality", () => {
|
||||
|
||||
await page.waitForURL("**/library/agents/**");
|
||||
const agentNameOnLibrary = await getId("agent-title").textContent();
|
||||
expect(
|
||||
agentNameOnLibrary && agentNameOnLibrary.trim().length,
|
||||
).toBeGreaterThan(0);
|
||||
|
||||
expect(agentNameOnLibrary?.trim()).toBe(agentName?.trim());
|
||||
});
|
||||
});
|
||||
|
||||
@@ -13,11 +13,7 @@ export class NavBar {
|
||||
}
|
||||
|
||||
async clickBuildLink() {
|
||||
const link = this.page.getByTestId("navbar-link-build");
|
||||
await link.waitFor({ state: "visible", timeout: 15000 });
|
||||
await link.scrollIntoViewIfNeeded();
|
||||
await link.click();
|
||||
await this.page.waitForURL(/\/build$/, { timeout: 15000 });
|
||||
await this.page.getByTestId("navbar-link-build").click();
|
||||
}
|
||||
|
||||
async clickMarketplaceLink() {
|
||||
|
||||