Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-04-30 03:00:41 -04:00

Compare commits (31 commits): symphony/S... → fix/backen
ffe9105157
4a1741cc15
c08b9774dc
fe3d6fb118
c6d31f8252
28ae7ebac8
e0f9146d54
c3c2737c42
37f247c795
ae4a421620
2879528308
1974ec6260
932ecd3a07
4a567a55a4
2b28434786
5d1cdc2bad
3c08b90500
599f370206
8786c00f9c
384cbd3ccd
8be9cf70af
a723966e0b
5b1d9763ed
10ea46663f
06188a86a6
2deac2073e
24406dfcec
000ddb007a
408b205515
f8c123a8c3
34374dfd55
@@ -186,7 +186,7 @@ Multiple worktrees share the same host — Docker infra (postgres, redis, clamav
 
 ### Lock file contract
 
-Path (**always** the root worktree so all siblings see it): `/Users/majdyz/Code/AutoGPT/.ign.testing.lock`
+Path (**always** the root worktree so all siblings see it): `$REPO_ROOT/.ign.testing.lock`
 
 Body (one `key=value` per line):
 
 ```
@@ -202,7 +202,7 @@ intent=<one-line description + rough duration>
 
 ### Claim
 
 ```bash
-LOCK=/Users/majdyz/Code/AutoGPT/.ign.testing.lock
+LOCK=$REPO_ROOT/.ign.testing.lock
 NOW=$(date -u +%Y-%m-%dT%H:%MZ)
 STALE_AFTER_MIN=5
@@ -252,7 +252,7 @@ echo "$HEARTBEAT_PID" > /tmp/pr-test-heartbeat.pid
 kill "$HEARTBEAT_PID" 2>/dev/null
 rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
 echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \
-  >> /Users/majdyz/Code/AutoGPT/.ign.testing.log
+  >> $REPO_ROOT/.ign.testing.log
 ```
 
 Use a `trap` so release runs even on `exit 1`:
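The `trap`-based release guarantee above can also be illustrated outside bash. A minimal Python sketch of the same pattern (a hypothetical context manager, not part of the repo) shows why release fires on failure as well as success:

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def held_lock(lock_path: str):
    """Release the lock on exit, success or failure (the bash `trap` analogue)."""
    open(lock_path, "w").close()  # claim
    try:
        yield lock_path
    finally:
        # Runs even when the body raises, like `trap '...' EXIT` after `exit 1`.
        if os.path.exists(lock_path):
            os.remove(lock_path)


lock = os.path.join(tempfile.mkdtemp(), "ign.testing.lock")
try:
    with held_lock(lock):
        held_during = os.path.exists(lock)
        raise RuntimeError("simulated failing test run")  # the `exit 1` case
except RuntimeError:
    pass
released_after = not os.path.exists(lock)
```

The `finally` block plays the role of the `EXIT` trap: no code path leaves the lock file behind.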
@@ -278,7 +278,7 @@ Concretely, the sequence at the end of every `/pr-test` run (success or failure)
 kill "$HEARTBEAT_PID" 2>/dev/null
 rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
 echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \
-  >> /Users/majdyz/Code/AutoGPT/.ign.testing.log
+  >> $REPO_ROOT/.ign.testing.log
 # 3. Optionally leave the app running and note it so the user knows:
 echo "Native stack still running on :3000 / :8006 for manual poking. Kill with:"
 echo "  pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'"
@@ -288,10 +288,10 @@ If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild da
 
 ### Shared status log
 
-`/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
+`$REPO_ROOT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
 
 ```bash
 echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] <message>" \
-  >> /Users/majdyz/Code/AutoGPT/.ign.testing.log
+  >> $REPO_ROOT/.ign.testing.log
 ```
 
 ## Step 3: Environment setup
79  .github/workflows/platform-backend-ci.yml  (vendored)
@@ -119,10 +119,12 @@ jobs:
     runs-on: ubuntu-latest
 
     services:
-      redis:
-        image: redis:latest
-        ports:
-          - 6379:6379
+      # Redis is provisioned as a real 3-shard cluster below via docker
+      # run (see the "Start Redis Cluster" step). GHA services can't
+      # override the image CMD or stand up multi-container clusters, so
+      # that setup is inlined — it mirrors the topology of the local dev
+      # compose stack (autogpt_platform/docker-compose.platform.yml) and
+      # prod helm chart.
       rabbitmq:
         image: rabbitmq:4.1.4
         ports:
@@ -166,6 +168,68 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
 
+      - name: Start Redis Cluster (3 shards)
+        run: |
+          # 3-master Redis Cluster matching the local compose stack
+          # (autogpt_platform/docker-compose.platform.yml) and prod. Each
+          # shard runs in its own container on a dedicated bridge network,
+          # announces its compose-style hostname for intra-network clients,
+          # and publishes 1700N on the GHA host so tests can reach every
+          # shard via localhost. The backend's ``_address_remap`` rewrites
+          # every CLUSTER SLOTS reply to localhost:<announced-port>, which
+          # picks the right published port per shard.
+          #
+          # Not reusing docker-compose.platform.yml directly because compose
+          # validates the full file even when only some services are ``up``,
+          # and that file references services (db/kong/...) defined in a
+          # sibling compose file — pulling both in would needlessly couple
+          # CI to the full local-dev stack.
+          docker network create redis-cluster-ci
+          for i in 0 1 2; do
+            port=$((17000 + i))
+            bus=$((27000 + i))
+            docker run -d --name redis-$i --network redis-cluster-ci \
+              --network-alias redis-$i \
+              -p $port:$port \
+              redis:7 \
+              redis-server --port $port \
+                --cluster-enabled yes \
+                --cluster-config-file nodes.conf \
+                --cluster-node-timeout 5000 \
+                --cluster-require-full-coverage no \
+                --cluster-announce-hostname redis-$i \
+                --cluster-announce-port $port \
+                --cluster-announce-bus-port $bus \
+                --cluster-preferred-endpoint-type hostname
+          done
+          # Wait for each shard to accept commands.
+          for i in 0 1 2; do
+            port=$((17000 + i))
+            for _ in $(seq 1 30); do
+              docker exec redis-$i redis-cli -p $port ping 2>/dev/null | grep -q PONG && break
+              sleep 1
+            done
+          done
+          # Form the cluster from an init container on the same network so
+          # --cluster-preferred-endpoint-type hostname resolves redis-0/1/2.
+          docker run --rm --network redis-cluster-ci redis:7 \
+            redis-cli --cluster create \
+              redis-0:17000 redis-1:17001 redis-2:17002 \
+              --cluster-replicas 0 --cluster-yes
+          # Confirm convergence.
+          for _ in $(seq 1 30); do
+            state=$(docker exec redis-0 redis-cli -p 17000 cluster info | awk -F: '/^cluster_state:/ {print $2}' | tr -d '[:cntrl:]')
+            if [ "$state" = "ok" ]; then
+              echo "Redis Cluster ready (3 shards, state=ok)"
+              docker exec redis-0 redis-cli -p 17000 cluster nodes
+              exit 0
+            fi
+            sleep 1
+          done
+          echo "Redis Cluster failed to reach ok state" >&2
+          docker exec redis-0 redis-cli -p 17000 cluster info >&2 || true
+          exit 1
+
       - name: Setup Supabase
         uses: supabase/setup-cli@v1
         with:
@@ -286,8 +350,13 @@ jobs:
           SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
           JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
           REDIS_HOST: "localhost"
-          REDIS_PORT: "6379"
+          REDIS_PORT: "17000"
           ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
+          # Opt-in: lets backend/data/e2e_redis_restart_test.py spin up its
+          # own isolated 3-shard cluster (ports 27110–27112) and exercise
+          # ``docker restart <shard>`` mid-stream. Off locally so a
+          # contributor's ``poetry run test`` doesn't pay the ~15s cost.
+          E2E_RESTART_ISOLATED: "1"
 
       - name: Upload coverage reports to Codecov
         if: ${{ !cancelled() }}
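The workflow comment above relies on each shard announcing a unique port so a host-to-localhost rewrite of CLUSTER SLOTS replies still routes to the right shard. A hedged sketch of such a remap callback (the repo's actual `_address_remap` is not shown in this diff; redis-py's `RedisCluster` exposes an `address_remap` hook taking and returning a `(host, port)` pair):

```python
def localhost_remap(address: tuple[str, int]) -> tuple[str, int]:
    """Rewrite a node address advertised in a CLUSTER SLOTS reply to
    localhost while keeping the announced port. Because each shard
    announces a unique port (17000-17002 in the CI step above), the port
    alone identifies the shard once the host collapses to localhost."""
    host, port = address
    return ("localhost", port)


# Hypothetical wiring via redis-py's address_remap hook:
#   from redis.cluster import RedisCluster
#   rc = RedisCluster(host="localhost", port=17000, address_remap=localhost_remap)
```

This is why the step publishes `-p $port:$port` with distinct ports rather than mapping every container to the same port.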
4  .gitignore  (vendored)
@@ -196,3 +196,7 @@ test.db
 plans/
 .claude/worktrees/
 test-results/
+
+# Playwright MCP / local browser-testing artifacts
+.playwright-mcp/
+copilot-session-switch-qa/
@@ -267,7 +267,7 @@
         "filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
         "hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
         "is_verified": false,
-        "line_number": 55
+        "line_number": 67
       }
     ],
     "autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
@@ -467,5 +467,5 @@
       }
     ]
   },
-  "generated_at": "2026-04-09T14:20:23Z"
+  "generated_at": "2026-04-24T16:42:44Z"
 }
@@ -1,33 +0,0 @@
-from typing import Optional
-
-from pydantic import Field
-from pydantic_settings import BaseSettings, SettingsConfigDict
-
-
-class RateLimitSettings(BaseSettings):
-    redis_host: str = Field(
-        default="redis://localhost:6379",
-        description="Redis host",
-        validation_alias="REDIS_HOST",
-    )
-
-    redis_port: str = Field(
-        default="6379", description="Redis port", validation_alias="REDIS_PORT"
-    )
-
-    redis_password: Optional[str] = Field(
-        default=None,
-        description="Redis password",
-        validation_alias="REDIS_PASSWORD",
-    )
-
-    requests_per_minute: int = Field(
-        default=60,
-        description="Maximum number of requests allowed per minute per API key",
-        validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE",
-    )
-
-    model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
-
-
-RATE_LIMIT_SETTINGS = RateLimitSettings()
@@ -1,51 +0,0 @@
-import time
-from typing import Tuple
-
-from redis import Redis
-
-from .config import RATE_LIMIT_SETTINGS
-
-
-class RateLimiter:
-    def __init__(
-        self,
-        redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
-        redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
-        redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
-        requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
-    ):
-        self.redis = Redis(
-            host=redis_host,
-            port=int(redis_port),
-            password=redis_password,
-            decode_responses=True,
-        )
-        self.window = 60
-        self.max_requests = requests_per_minute
-
-    async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]:
-        """
-        Check if request is within rate limits.
-
-        Args:
-            api_key_id: The API key identifier to check
-
-        Returns:
-            Tuple of (is_allowed, remaining_requests, reset_time)
-        """
-        now = time.time()
-        window_start = now - self.window
-        key = f"ratelimit:{api_key_id}:1min"
-
-        pipe = self.redis.pipeline()
-        pipe.zremrangebyscore(key, 0, window_start)
-        pipe.zadd(key, {str(now): now})
-        pipe.zcount(key, window_start, now)
-        pipe.expire(key, self.window)
-
-        _, _, request_count, _ = pipe.execute()
-
-        remaining = max(0, self.max_requests - request_count)
-        reset_time = int(now + self.window)
-
-        return request_count <= self.max_requests, remaining, reset_time
@@ -1,32 +0,0 @@
-from fastapi import HTTPException, Request
-from starlette.middleware.base import RequestResponseEndpoint
-
-from .limiter import RateLimiter
-
-
-async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint):
-    """FastAPI middleware for rate limiting API requests."""
-    limiter = RateLimiter()
-
-    if not request.url.path.startswith("/api"):
-        return await call_next(request)
-
-    api_key = request.headers.get("Authorization")
-    if not api_key:
-        return await call_next(request)
-
-    api_key = api_key.replace("Bearer ", "")
-
-    is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key)
-
-    if not is_allowed:
-        raise HTTPException(
-            status_code=429, detail="Rate limit exceeded. Please try again later."
-        )
-
-    response = await call_next(request)
-    response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
-    response.headers["X-RateLimit-Remaining"] = str(remaining)
-    response.headers["X-RateLimit-Reset"] = str(reset_time)
-
-    return response
@@ -1,13 +1,16 @@
 import asyncio
 from contextlib import asynccontextmanager
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union
 
 from expiringdict import ExpiringDict
 
 if TYPE_CHECKING:
     from redis.asyncio import Redis as AsyncRedis
+    from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
     from redis.asyncio.lock import Lock as AsyncRedisLock
+
+    AsyncRedisLike = Union[AsyncRedis, AsyncRedisCluster]
 
 
 class AsyncRedisKeyedMutex:
     """
@@ -17,7 +20,7 @@ class AsyncRedisKeyedMutex:
     in case the key is not unlocked for a specified duration, to prevent memory leaks.
     """
 
-    def __init__(self, redis: "AsyncRedis", timeout: int | None = 60):
+    def __init__(self, redis: "AsyncRedisLike", timeout: int | None = 60):
         self.redis = redis
         self.timeout = timeout
         self.locks: dict[Any, "AsyncRedisLock"] = ExpiringDict(
@@ -37,6 +37,23 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
 ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
 UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
 
+# Web Push (VAPID) — generate with: poetry run python -c "
+# from py_vapid import Vapid; import base64
+# from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
+# v = Vapid(); v.generate_keys()
+# raw_priv = v.private_key.private_numbers().private_value.to_bytes(32, 'big')
+# print('VAPID_PRIVATE_KEY=' + base64.urlsafe_b64encode(raw_priv).rstrip(b'=').decode())
+# raw_pub = v.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
+# print('VAPID_PUBLIC_KEY=' + base64.urlsafe_b64encode(raw_pub).rstrip(b'=').decode())
+# "
+# Dev-only keypair below — DO NOT use in staging/production. Regenerate
+# your own with the snippet above before any non-local deployment.
+VAPID_PRIVATE_KEY=17hBPdSdn6TR_yAgQxA0TjTcvRj3Lf6znHnASZ4rOKc
+VAPID_PUBLIC_KEY=BBg49iVTWthVbRYphwmZNvZyiSJDqtSO4nmLxDzLKe3Oo9jbtu0Usa14xX4HQQNLUeiEfzD42zWSlrvY1PR12bs
+# Per RFC 8292 push services use this in 410 Gone reports; set to a real
+# mailbox in production. Defaults to a placeholder for local dev.
+VAPID_CLAIM_EMAIL=mailto:dev@example.com
+
 ## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
 # Platform URLs (set these for webhooks and OAuth to work)
 PLATFORM_BASE_URL=http://localhost:8000
@@ -182,6 +199,10 @@ GOOGLE_MAPS_API_KEY=
 # Platform Bot Linking
 PLATFORM_LINK_BASE_URL=http://localhost:3000/link
 
+# CoPilot chat-platform bridge (Discord/Telegram/Slack)
+# Uses FRONTEND_BASE_URL (above) for link confirmation pages.
+AUTOPILOT_BOT_DISCORD_TOKEN=
+
 # Communication Services
 DISCORD_BOT_TOKEN=
 MEDIUM_API_KEY=
@@ -1,14 +1,44 @@
 import asyncio
-from typing import Dict, Set
+import json
+import logging
+import time
+from typing import Awaitable, Callable, Dict, Optional, Set
 
-from fastapi import WebSocket
+from fastapi import WebSocket, WebSocketDisconnect
+from redis.asyncio import Redis as AsyncRedis
+from redis.asyncio.client import PubSub as AsyncPubSub
+from redis.exceptions import MovedError, RedisError, ResponseError
+from starlette.websockets import WebSocketState
 
-from backend.api.model import NotificationPayload, WSMessage, WSMethod
+from backend.api.model import WSMessage, WSMethod
+from backend.data import redis_client as redis
+from backend.data.event_bus import _assert_no_wildcard
 from backend.data.execution import (
     ExecutionEventType,
     GraphExecutionEvent,
     NodeExecutionEvent,
+    exec_channel,
+    get_graph_execution_meta,
+    graph_all_channel,
 )
+from backend.data.notification_bus import NotificationEvent
+from backend.util.settings import Settings
+
+logger = logging.getLogger(__name__)
+_settings = Settings()
+
+
+def _is_ws_close_race(exc: BaseException, websocket: WebSocket) -> bool:
+    """A SPUBLISH→WS send racing with WS close — benign, drop quietly."""
+    if isinstance(exc, WebSocketDisconnect):
+        return True
+    if (
+        getattr(websocket, "application_state", None) == WebSocketState.DISCONNECTED
+        or getattr(websocket, "client_state", None) == WebSocketState.DISCONNECTED
+    ):
+        return True
+    if isinstance(exc, RuntimeError) and "close message has been sent" in str(exc):
+        return True
+    return False
+
+
 _EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
     ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
@@ -16,128 +46,379 @@ _EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
 }
 
 
+def event_bus_channel(channel_key: str) -> str:
+    """Prefix a channel key with the execution event bus name."""
+    return f"{_settings.config.execution_event_bus_name}/{channel_key}"
+
+
+def _notification_bus_channel(user_id: str) -> str:
+    """Return the full sharded channel name for a user's notifications."""
+    return f"{_settings.config.notification_event_bus_name}/{user_id}"
+
+
+MessageHandler = Callable[[Optional[bytes | str]], Awaitable[None]]
+
+
+def _is_moved_error(exc: BaseException) -> bool:
+    """A MOVED redirect — slot migration mid-stream; pump should reconnect."""
+    if isinstance(exc, MovedError):
+        return True
+    if isinstance(exc, ResponseError) and str(exc).startswith("MOVED "):
+        return True
+    return False
+
+
+# Reconnect tunables for shard-failover during pubsub.listen().
+_PUMP_RECONNECT_DEADLINE_S = 60.0
+_PUMP_RECONNECT_BACKOFF_INITIAL_S = 0.5
+_PUMP_RECONNECT_BACKOFF_MAX_S = 8.0
+
+
+class _Subscription:
+    """One SSUBSCRIBE lifecycle bound to a WebSocket, pinned to the owning shard."""
+
+    def __init__(self, full_channel: str) -> None:
+        _assert_no_wildcard(full_channel)
+        self.full_channel = full_channel
+        self._client: AsyncRedis | None = None
+        self._pubsub: AsyncPubSub | None = None
+        self._task: asyncio.Task | None = None
+
+    async def start(self, on_message: MessageHandler) -> None:
+        await self._open_pubsub()
+        self._task = asyncio.create_task(self._pump(on_message))
+
+    async def _open_pubsub(self) -> None:
+        """(Re)establish the sharded pubsub connection + SSUBSCRIBE."""
+        self._client = await redis.connect_sharded_pubsub_async(self.full_channel)
+        self._pubsub = self._client.pubsub()
+        await self._pubsub.execute_command("SSUBSCRIBE", self.full_channel)
+        # redis-py 6.x async PubSub.listen() exits when ``channels`` is
+        # empty; raw SSUBSCRIBE doesn't populate it, so do it ourselves.
+        self._pubsub.channels[self.full_channel] = None  # type: ignore[index]
+
+    async def _close_pubsub_quietly(self) -> None:
+        """Best-effort teardown before reconnect — never raises."""
+        if self._pubsub is not None:
+            try:
+                await self._pubsub.aclose()
+            except Exception:
+                pass
+            self._pubsub = None
+        if self._client is not None:
+            try:
+                await self._client.aclose()
+            except Exception:
+                pass
+            self._client = None
+
+    async def _pump(self, on_message: MessageHandler) -> None:
+        if self._pubsub is None:
+            return
+        backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
+        deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
+        while True:
+            pubsub = self._pubsub
+            if pubsub is None:
+                return
+            needs_reconnect = False
+            try:
+                async for message in pubsub.listen():
+                    msg_type = message.get("type")
+                    # Server-pushed sunsubscribe: slot ownership changed and
+                    # Redis revoked our SSUBSCRIBE without dropping the TCP.
+                    # Treat as a reconnect trigger so we re-resolve the shard.
+                    if msg_type == "sunsubscribe":
+                        needs_reconnect = True
+                        break
+                    if msg_type not in ("smessage", "message", "pmessage"):
+                        continue
+                    # Successful read resets the reconnect budget.
+                    backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
+                    deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
+                    try:
+                        await on_message(message.get("data"))
+                    except Exception:
+                        logger.exception(
+                            "Websocket message-handler failed for channel %s",
+                            self.full_channel,
+                        )
+                if not needs_reconnect:
+                    # listen() exited cleanly (channels emptied) — pump is done.
+                    return
+            except asyncio.CancelledError:
+                raise
+            except (ConnectionError, RedisError) as exc:
+                if isinstance(exc, ResponseError) and not _is_moved_error(exc):
+                    logger.exception(
+                        "Pubsub pump crashed on non-retryable ResponseError for %s",
+                        self.full_channel,
+                    )
+                    return
+                if time.monotonic() > deadline:
+                    logger.exception(
+                        "Pubsub pump giving up after reconnect deadline for %s",
+                        self.full_channel,
+                    )
+                    return
+                logger.warning(
+                    "Pubsub pump reconnecting for %s after %s: %s",
+                    self.full_channel,
+                    type(exc).__name__,
+                    exc,
+                )
+            except Exception:
+                logger.exception("Pubsub pump crashed for %s", self.full_channel)
+                return
+
+            # Either a retryable error was raised, or the server pushed a
+            # sunsubscribe — close the stale pubsub and reopen against the
+            # (possibly migrated) shard.
+            await self._close_pubsub_quietly()
+            await asyncio.sleep(backoff)
+            backoff = min(backoff * 2, _PUMP_RECONNECT_BACKOFF_MAX_S)
+            try:
+                await self._open_pubsub()
+            except (ConnectionError, RedisError) as reopen_exc:
+                logger.warning(
+                    "Pubsub pump reopen failed for %s: %s",
+                    self.full_channel,
+                    reopen_exc,
+                )
+                # Loop again — deadline check will eventually exit.
+            continue
+
+    async def stop(self) -> None:
+        if self._task is not None:
+            self._task.cancel()
+            try:
+                await self._task
+            except (asyncio.CancelledError, Exception):
+                pass
+            self._task = None
+        if self._pubsub is not None:
+            try:
+                await self._pubsub.execute_command("SUNSUBSCRIBE", self.full_channel)
+            except Exception:
+                logger.warning(
+                    "SUNSUBSCRIBE failed for %s", self.full_channel, exc_info=True
+                )
+            try:
+                await self._pubsub.aclose()
+            except Exception:
+                pass
+            self._pubsub = None
+        if self._client is not None:
+            try:
+                await self._client.aclose()
+            except Exception:
+                pass
+            self._client = None
+
+
 class ConnectionManager:
     def __init__(self):
         self.active_connections: Set[WebSocket] = set()
+        # channel_key → sockets subscribed (public channel keys, not raw Redis channels)
         self.subscriptions: Dict[str, Set[WebSocket]] = {}
         self.user_connections: Dict[str, Set[WebSocket]] = {}
+        # websocket → {channel_key: _Subscription}
+        self._ws_subs: Dict[WebSocket, Dict[str, _Subscription]] = {}
+        # websocket → notification subscription
+        self._ws_notifications: Dict[WebSocket, _Subscription] = {}
 
     async def connect_socket(self, websocket: WebSocket, *, user_id: str):
         await websocket.accept()
         self.active_connections.add(websocket)
         if user_id not in self.user_connections:
             self.user_connections[user_id] = set()
         self.user_connections[user_id].add(websocket)
+        self._ws_subs.setdefault(websocket, {})
+        await self._start_notification_subscription(websocket, user_id=user_id)
 
-    def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
+    async def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
         self.active_connections.discard(websocket)
-        for subscribers in self.subscriptions.values():
+        # Stop SSUBSCRIBE pumps before dropping bookkeeping to avoid leaks.
+        subs = self._ws_subs.pop(websocket, {})
+        for sub in subs.values():
+            await sub.stop()
+        notif_sub = self._ws_notifications.pop(websocket, None)
+        if notif_sub is not None:
+            await notif_sub.stop()
+        for channel_key, subscribers in list(self.subscriptions.items()):
             subscribers.discard(websocket)
+            if not subscribers:
+                self.subscriptions.pop(channel_key, None)
         user_conns = self.user_connections.get(user_id)
         if user_conns is not None:
             user_conns.discard(websocket)
             if not user_conns:
                 self.user_connections.pop(user_id, None)
 
     async def subscribe_graph_exec(
         self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
     ) -> str:
-        return await self._subscribe(
-            _graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
-        )
+        # Hash-tagged channel needs graph_id; resolve once per subscribe.
+        meta = await get_graph_execution_meta(user_id, graph_exec_id)
+        if meta is None:
+            raise ValueError(
+                f"graph_exec #{graph_exec_id} not found for user #{user_id}"
+            )
+        channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
+        full_channel = event_bus_channel(
+            exec_channel(user_id, meta.graph_id, graph_exec_id)
+        )
+        await self._open_subscription(websocket, channel_key, full_channel)
+        return channel_key
 
     async def subscribe_graph_execs(
         self, *, user_id: str, graph_id: str, websocket: WebSocket
     ) -> str:
-        return await self._subscribe(
-            _graph_execs_channel_key(user_id, graph_id=graph_id), websocket
-        )
+        channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
+        full_channel = event_bus_channel(graph_all_channel(user_id, graph_id))
+        await self._open_subscription(websocket, channel_key, full_channel)
+        return channel_key
 
     async def unsubscribe_graph_exec(
         self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
     ) -> str | None:
-        return await self._unsubscribe(
-            _graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
-        )
+        channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
+        return await self._close_subscription(websocket, channel_key)
 
     async def unsubscribe_graph_execs(
         self, *, user_id: str, graph_id: str, websocket: WebSocket
     ) -> str | None:
-        return await self._unsubscribe(
-            _graph_execs_channel_key(user_id, graph_id=graph_id), websocket
-        )
+        channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
+        return await self._close_subscription(websocket, channel_key)
 
-    async def send_execution_update(
-        self, exec_event: GraphExecutionEvent | NodeExecutionEvent
-    ) -> int:
-        graph_exec_id = (
-            exec_event.id
-            if isinstance(exec_event, GraphExecutionEvent)
-            else exec_event.graph_exec_id
-        )
+    async def _open_subscription(
+        self, websocket: WebSocket, channel_key: str, full_channel: str
+    ) -> None:
+        self.subscriptions.setdefault(channel_key, set()).add(websocket)
+        per_ws = self._ws_subs.setdefault(websocket, {})
+        if channel_key in per_ws:
+            return
+        sub = _Subscription(full_channel)
 
-        n_sent = 0
+        async def on_message(data: Optional[bytes | str]) -> None:
+            await self._forward_exec_event(websocket, channel_key, data)
 
-        channels: set[str] = {
-            # Send update to listeners for this graph execution
-            _graph_exec_channel_key(exec_event.user_id, graph_exec_id=graph_exec_id)
-        }
-        if isinstance(exec_event, GraphExecutionEvent):
-            # Send update to listeners for all executions of this graph
-            channels.add(
-                _graph_execs_channel_key(
-                    exec_event.user_id, graph_id=exec_event.graph_id
-                )
-            )
+        await sub.start(on_message)
+        per_ws[channel_key] = sub
 
-        for channel in channels.intersection(self.subscriptions.keys()):
-            message = WSMessage(
-                method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
-                channel=channel,
-                data=exec_event.model_dump(),
-            ).model_dump_json()
-            for connection in self.subscriptions[channel]:
-                await connection.send_text(message)
-                n_sent += 1
-
-        return n_sent
-
-    async def send_notification(
-        self, *, user_id: str, payload: NotificationPayload
-    ) -> int:
-        """Send a notification to all websocket connections belonging to a user."""
-        message = WSMessage(
-            method=WSMethod.NOTIFICATION,
-            data=payload.model_dump(),
-        ).model_dump_json()
-
-        connections = tuple(self.user_connections.get(user_id, set()))
-        if not connections:
-            return 0
-
-        await asyncio.gather(
-            *(connection.send_text(message) for connection in connections),
-            return_exceptions=True,
-        )
-
-        return len(connections)
-
-    async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
-        if channel_key not in self.subscriptions:
-            self.subscriptions[channel_key] = set()
-        self.subscriptions[channel_key].add(websocket)
+    async def _close_subscription(
+        self, websocket: WebSocket, channel_key: str
+    ) -> str | None:
+        subscribers = self.subscriptions.get(channel_key)
+        if subscribers is None:
+            return None
+        subscribers.discard(websocket)
+        if not subscribers:
+            self.subscriptions.pop(channel_key, None)
+        per_ws = self._ws_subs.get(websocket)
+        if per_ws and channel_key in per_ws:
+            sub = per_ws.pop(channel_key)
+            await sub.stop()
         return channel_key
 
-    async def _unsubscribe(self, channel_key: str, websocket: WebSocket) -> str | None:
-        if channel_key in self.subscriptions:
-            self.subscriptions[channel_key].discard(websocket)
-            if not self.subscriptions[channel_key]:
-                del self.subscriptions[channel_key]
-            return channel_key
-        return None
+    async def _forward_exec_event(
+        self,
+        websocket: WebSocket,
+        channel_key: str,
+        raw_payload: Optional[bytes | str],
+    ) -> None:
+        if raw_payload is None:
+            return
+        # Unwrap the `_EventPayloadWrapper` envelope, then re-wrap as a WS message.
+        try:
+            wrapper = (
+                raw_payload.decode()
+                if isinstance(raw_payload, (bytes, bytearray))
+                else raw_payload
+            )
+        except Exception:
+            logger.warning(
+                "Failed to decode pubsub payload on %s", channel_key, exc_info=True
+            )
+            return
+
+        try:
+            parsed = json.loads(wrapper)
+            event_data = parsed.get("payload")
+            if not isinstance(event_data, dict):
+                return
+            event_type = event_data.get("event_type")
+            method = _EVENT_TYPE_TO_METHOD_MAP.get(ExecutionEventType(event_type))
+            if method is None:
+                return
+            message = WSMessage(
+                method=method,
+                channel=channel_key,
+                data=event_data,
+            ).model_dump_json()
+            await websocket.send_text(message)
+        except Exception as e:
+            if _is_ws_close_race(e, websocket):
+                logger.debug("Dropped exec event on closed WS for %s", channel_key)
+                return
+            logger.exception("Failed to forward exec event on %s", channel_key)
+
+    async def _start_notification_subscription(
+        self, websocket: WebSocket, *, user_id: str
+    ) -> None:
+        full_channel = _notification_bus_channel(user_id)
+        sub = _Subscription(full_channel)
+
+        async def on_message(data: Optional[bytes | str]) -> None:
+            await self._forward_notification(websocket, user_id, data)
+
+        try:
+            await sub.start(on_message)
+        except Exception:
+            logger.exception(
+                "Failed to open notification SSUBSCRIBE for user=%s", user_id
+            )
+            return
+        self._ws_notifications[websocket] = sub
+
+    async def _forward_notification(
+        self,
+        websocket: WebSocket,
+        user_id: str,
+        raw_payload: Optional[bytes | str],
+    ) -> None:
+        if raw_payload is None:
+            return
+        try:
+            wrapper_json = (
+                raw_payload.decode()
+                if isinstance(raw_payload, (bytes, bytearray))
+                else raw_payload
+            )
+            parsed = json.loads(wrapper_json)
+            inner = parsed.get("payload") if isinstance(parsed, dict) else None
+            if not isinstance(inner, dict):
+                return
+            event = NotificationEvent.model_validate(inner)
+        except Exception:
+            logger.warning(
+                "Failed to parse notification payload for user=%s",
+                user_id,
+                exc_info=True,
+            )
+            return
+        # Defense in depth against cross-user payloads.
+        if event.user_id != user_id:
+            return
+        message = WSMessage(
+            method=WSMethod.NOTIFICATION,
+            data=event.payload.model_dump(),
+        ).model_dump_json()
+        try:
+            await websocket.send_text(message)
+        except Exception as e:
+            if _is_ws_close_race(e, websocket):
+                logger.debug("Dropped notification on closed WS for user=%s", user_id)
|
||||
return
|
||||
logger.warning(
|
||||
"Failed to deliver notification to WS for user=%s",
|
||||
user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def _graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
|
||||
def graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
|
||||
return f"{user_id}|graph_exec#{graph_exec_id}"
|
||||
|
||||
|
||||
|
||||
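The rename above keeps the one-line key builder intact; the scheme is a plain pipe-and-hash composite. In isolation (the function body is copied from the diff, the surrounding calls are illustration only):

```python
def graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
    # Per-execution channel key: "<user_id>|graph_exec#<graph_exec_id>".
    return f"{user_id}|graph_exec#{graph_exec_id}"


key = graph_exec_channel_key("user-1", graph_exec_id="exec-42")
print(key)  # → user-1|graph_exec#exec-42
```

The user ID prefix means every key is naturally namespaced per user, which is what the cross-user checks in the tests below rely on.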
@@ -0,0 +1,386 @@
"""ConnectionManager integration over the live 3-shard Redis cluster:
SSUBSCRIBE → SPUBLISH → WebSocket forwarding with no Redis mocks. Skips
when the cluster is unreachable."""

import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock
from uuid import uuid4

import pytest
from fastapi import WebSocket

import backend.data.redis_client as redis_client
from backend.api.conn_manager import (
    ConnectionManager,
    _graph_execs_channel_key,
    event_bus_channel,
    graph_exec_channel_key,
)
from backend.api.model import WSMethod
from backend.data.execution import (
    ExecutionStatus,
    GraphExecutionEvent,
    GraphExecutionMeta,
    NodeExecutionEvent,
    exec_channel,
    graph_all_channel,
)


def _has_live_cluster() -> bool:
    try:
        c = redis_client.connect()
    except Exception:  # noqa: BLE001 — any connect failure → skip
        return False
    try:
        c.close()
    except Exception:
        pass
    return True


pytestmark = pytest.mark.skipif(
    not _has_live_cluster(),
    reason="local redis cluster not reachable; skip conn_manager integration",
)


def _meta(user_id: str, graph_id: str, graph_exec_id: str) -> GraphExecutionMeta:
    """Build a minimal GraphExecutionMeta for ``subscribe_graph_exec`` to use."""
    return GraphExecutionMeta(
        id=graph_exec_id,
        user_id=user_id,
        graph_id=graph_id,
        graph_version=1,
        inputs=None,
        credential_inputs=None,
        nodes_input_masks=None,
        preset_id=None,
        status=ExecutionStatus.RUNNING,
        started_at=datetime.now(tz=timezone.utc),
        ended_at=None,
        stats=GraphExecutionMeta.Stats(),
    )


def _node_event_payload(
    *, user_id: str, graph_id: str, graph_exec_id: str, marker: str
) -> bytes:
    """Wire-format a NodeExecutionEvent the way RedisExecutionEventBus would."""
    inner = NodeExecutionEvent(
        user_id=user_id,
        graph_id=graph_id,
        graph_version=1,
        graph_exec_id=graph_exec_id,
        node_exec_id=f"node-exec-{marker}",
        node_id="node-1",
        block_id="block-1",
        status=ExecutionStatus.COMPLETED,
        input_data={"in": marker},
        output_data={"out": [marker]},
        add_time=datetime.now(tz=timezone.utc),
        queue_time=None,
        start_time=datetime.now(tz=timezone.utc),
        end_time=datetime.now(tz=timezone.utc),
    ).model_dump(mode="json")
    return json.dumps({"payload": inner}).encode()


def _graph_event_payload(
    *, user_id: str, graph_id: str, graph_exec_id: str, marker: str
) -> bytes:
    inner = GraphExecutionEvent(
        id=graph_exec_id,
        user_id=user_id,
        graph_id=graph_id,
        graph_version=1,
        preset_id=None,
        status=ExecutionStatus.COMPLETED,
        started_at=datetime.now(tz=timezone.utc),
        ended_at=datetime.now(tz=timezone.utc),
        stats=GraphExecutionEvent.Stats(
            cost=0,
            duration=1.0,
            node_exec_time=0.5,
            node_exec_count=1,
        ),
        inputs={"x": marker},
        credential_inputs=None,
        nodes_input_masks=None,
        outputs={"y": [marker]},
    ).model_dump(mode="json")
    return json.dumps({"payload": inner}).encode()


async def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.05) -> bool:
    """Poll ``predicate()`` until truthy or timeout — used to wait for pubsub."""
    deadline = asyncio.get_event_loop().time() + timeout
    while asyncio.get_event_loop().time() < deadline:
        if predicate():
            return True
        await asyncio.sleep(interval)
    return False


@pytest.mark.asyncio
async def test_two_clients_get_independent_ssubscribes_on_right_shards(
    monkeypatch,
) -> None:
    """Two WS clients on different graph_exec_ids each receive ONLY their
    own publish, even when the channels land on different shards."""
    user_id = "user-conn-int-1"
    graph_a = f"graph-a-{uuid4().hex[:8]}"
    graph_b = f"graph-b-{uuid4().hex[:8]}"
    exec_a = f"exec-a-{uuid4().hex[:8]}"
    exec_b = f"exec-b-{uuid4().hex[:8]}"

    # Stub Prisma lookup so tests don't need a DB.
    async def _fake_meta(_uid, gex_id):
        return _meta(user_id, graph_a if gex_id == exec_a else graph_b, gex_id)

    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)

    cm = ConnectionManager()
    ws_a: AsyncMock = AsyncMock(spec=WebSocket)
    ws_b: AsyncMock = AsyncMock(spec=WebSocket)
    sent_a: list[str] = []
    sent_b: list[str] = []
    ws_a.send_text = AsyncMock(side_effect=lambda m: sent_a.append(m))
    ws_b.send_text = AsyncMock(side_effect=lambda m: sent_b.append(m))

    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()

    try:
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_a, websocket=ws_a
        )
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_b, websocket=ws_b
        )
        # Let SSUBSCRIBE settle on each shard.
        await asyncio.sleep(0.2)

        # Publish to each per-exec channel.
        chan_a = event_bus_channel(exec_channel(user_id, graph_a, exec_a))
        chan_b = event_bus_channel(exec_channel(user_id, graph_b, exec_b))
        cluster.spublish(
            chan_a,
            _node_event_payload(
                user_id=user_id,
                graph_id=graph_a,
                graph_exec_id=exec_a,
                marker="A",
            ).decode(),
        )
        cluster.spublish(
            chan_b,
            _node_event_payload(
                user_id=user_id,
                graph_id=graph_b,
                graph_exec_id=exec_b,
                marker="B",
            ).decode(),
        )

        delivered = await _wait_until(lambda: sent_a and sent_b, timeout=5.0)
        assert delivered, f"timeout: sent_a={sent_a!r} sent_b={sent_b!r}"

        msg_a = json.loads(sent_a[0])
        msg_b = json.loads(sent_b[0])
        assert msg_a["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_a)
        assert msg_b["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_b)
        assert msg_a["data"]["graph_exec_id"] == exec_a
        assert msg_b["data"]["graph_exec_id"] == exec_b
        # No cross-talk: each socket got exactly one message.
        assert len(sent_a) == 1 and len(sent_b) == 1
    finally:
        await cm.disconnect_socket(ws_a, user_id=user_id)
        await cm.disconnect_socket(ws_b, user_id=user_id)
        redis_client.disconnect()


@pytest.mark.asyncio
async def test_aggregate_channel_receives_per_exec_publishes(monkeypatch) -> None:
    """A subscriber on the ``graph_execs`` aggregate channel must receive the
    GraphExecutionEvent published to the ``/all`` channel — even though
    per-exec events go to a different channel."""
    user_id = "user-conn-int-2"
    graph_id = f"graph-{uuid4().hex[:8]}"
    exec_id = f"exec-{uuid4().hex[:8]}"

    async def _fake_meta(_uid, gex_id):
        return _meta(user_id, graph_id, gex_id)

    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)

    cm = ConnectionManager()
    ws_agg: AsyncMock = AsyncMock(spec=WebSocket)
    ws_per: AsyncMock = AsyncMock(spec=WebSocket)
    sent_agg: list[str] = []
    sent_per: list[str] = []
    ws_agg.send_text = AsyncMock(side_effect=lambda m: sent_agg.append(m))
    ws_per.send_text = AsyncMock(side_effect=lambda m: sent_per.append(m))

    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()

    try:
        await cm.subscribe_graph_execs(
            user_id=user_id, graph_id=graph_id, websocket=ws_agg
        )
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_id, websocket=ws_per
        )
        await asyncio.sleep(0.2)

        # The eventbus publishes the same event to both channels — replicate.
        chan_per = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
        chan_all = event_bus_channel(graph_all_channel(user_id, graph_id))
        payload = _graph_event_payload(
            user_id=user_id,
            graph_id=graph_id,
            graph_exec_id=exec_id,
            marker="agg",
        ).decode()
        cluster.spublish(chan_per, payload)
        cluster.spublish(chan_all, payload)

        delivered = await _wait_until(lambda: sent_agg and sent_per, timeout=5.0)
        assert delivered, f"sent_agg={sent_agg!r} sent_per={sent_per!r}"
        agg_msg = json.loads(sent_agg[0])
        per_msg = json.loads(sent_per[0])
        # Aggregate subscriber's channel key is the per-graph executions key.
        assert agg_msg["channel"] == _graph_execs_channel_key(
            user_id, graph_id=graph_id
        )
        assert per_msg["channel"] == graph_exec_channel_key(
            user_id, graph_exec_id=exec_id
        )
        assert agg_msg["method"] == WSMethod.GRAPH_EXECUTION_EVENT.value
    finally:
        await cm.disconnect_socket(ws_agg, user_id=user_id)
        await cm.disconnect_socket(ws_per, user_id=user_id)
        redis_client.disconnect()


@pytest.mark.asyncio
async def test_disconnect_unsubscribes_and_drops_future_publishes(monkeypatch) -> None:
    """After ``disconnect_socket`` runs, a subsequent SPUBLISH must NOT reach
    the dead websocket — exercises the SUNSUBSCRIBE + bookkeeping cleanup."""
    user_id = "user-conn-int-3"
    graph_id = f"graph-{uuid4().hex[:8]}"
    exec_id = f"exec-{uuid4().hex[:8]}"

    async def _fake_meta(_uid, gex_id):
        return _meta(user_id, graph_id, gex_id)

    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)

    cm = ConnectionManager()
    ws: AsyncMock = AsyncMock(spec=WebSocket)
    sent: list[str] = []
    ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))

    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()
    chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
    payload = _node_event_payload(
        user_id=user_id, graph_id=graph_id, graph_exec_id=exec_id, marker="live"
    ).decode()

    try:
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_id, websocket=ws
        )
        await asyncio.sleep(0.15)

        # First publish — must reach the socket.
        cluster.spublish(chan, payload)
        delivered = await _wait_until(lambda: bool(sent), timeout=5.0)
        assert delivered
        assert len(sent) == 1

        # Disconnect → SUNSUBSCRIBE + bookkeeping cleared.
        await cm.disconnect_socket(ws, user_id=user_id)
        # Pump cancellation may drain in-flight messages; wait for it.
        await asyncio.sleep(0.2)

        # Channel bookkeeping must be gone.
        assert (
            graph_exec_channel_key(user_id, graph_exec_id=exec_id)
            not in cm.subscriptions
        )
        assert ws not in cm._ws_subs

        # Second publish — must NOT reach the (already-disconnected) socket.
        cluster.spublish(
            chan,
            _node_event_payload(
                user_id=user_id,
                graph_id=graph_id,
                graph_exec_id=exec_id,
                marker="post-disconnect",
            ).decode(),
        )
        await asyncio.sleep(0.5)
        # Still only the one pre-disconnect message.
        assert len(sent) == 1
    finally:
        redis_client.disconnect()


@pytest.mark.asyncio
async def test_slow_consumer_receives_all_events_without_loss(monkeypatch) -> None:
    """Burst-publish many SPUBLISHes; assert every one reaches the subscriber
    in order — guards against drops/reorderings in the pubsub pump."""
    user_id = "user-conn-int-4"
    graph_id = f"graph-{uuid4().hex[:8]}"
    exec_id = f"exec-{uuid4().hex[:8]}"
    n_events = 100

    async def _fake_meta(_uid, gex_id):
        return _meta(user_id, graph_id, gex_id)

    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)

    cm = ConnectionManager()
    ws: AsyncMock = AsyncMock(spec=WebSocket)
    sent: list[str] = []
    ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))

    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()
    chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))

    try:
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_id, websocket=ws
        )
        await asyncio.sleep(0.2)

        # Burst-publish n_events without yielding to the pump.
        for i in range(n_events):
            cluster.spublish(
                chan,
                _node_event_payload(
                    user_id=user_id,
                    graph_id=graph_id,
                    graph_exec_id=exec_id,
                    marker=f"m{i}",
                ).decode(),
            )

        delivered = await _wait_until(
            lambda: len(sent) >= n_events, timeout=15.0, interval=0.1
        )
        assert delivered, f"only delivered {len(sent)}/{n_events}"

        # Validate ordering — Redis pub/sub is FIFO per channel.
        markers = [json.loads(m)["data"]["input_data"]["in"] for m in sent[:n_events]]
        assert markers == [f"m{i}" for i in range(n_events)]
    finally:
        await cm.disconnect_socket(ws, user_id=user_id)
        redis_client.disconnect()
File diff suppressed because it is too large

@@ -8,7 +8,7 @@ from uuid import uuid4

from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, Field, field_validator

from backend.copilot import service as chat_service
@@ -47,7 +47,14 @@ from backend.copilot.rate_limit import (
    release_reset_lock,
    reset_daily_usage,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.response_model import (
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamHeartbeat,
    StreamStart,
    StreamStartStep,
)
from backend.copilot.service import strip_injected_context_for_display
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
@@ -154,6 +161,14 @@ class StreamChatRequest(BaseModel):
    )


class QueuePendingMessageRequest(BaseModel):
    """Request model for queueing a follow-up while a turn is running."""

    message: str = Field(max_length=64_000)
    context: dict[str, str] | None = None
    file_ids: list[str] | None = Field(default=None, max_length=20)


class PeekPendingMessagesResponse(BaseModel):
    """Response for the pending-message peek (GET) endpoint.

@@ -209,6 +224,11 @@ class ActiveStreamInfo(BaseModel):

    turn_id: str
    last_message_id: str  # Redis Stream message ID for resumption
    # ISO-8601 timestamp (UTC) marking when the backend registered the turn
    # as running. Lets the frontend seed its elapsed-time counter so restored
    # turns show honest "time since turn started" instead of the misleading
    # "time since this mount resumed the SSE".
    started_at: str | None = None


class SessionDetailResponse(BaseModel):
@@ -300,8 +320,11 @@ async def list_sessions(
    redis = await get_redis_async()
    pipe = redis.pipeline(transaction=False)
    for session in sessions:
        # Use the canonical helper so the hash-tag braces match every
        # other writer; building the key inline drops the braces and
        # silently misses every running session on cluster mode.
        pipe.hget(
            f"{config.session_meta_prefix}{session.session_id}",
            stream_registry.get_session_meta_key(session.session_id),
            "status",
        )
    statuses = await pipe.execute()
@@ -529,6 +552,7 @@ async def get_session(
        active_stream_info = ActiveStreamInfo(
            turn_id=active_session.turn_id,
            last_message_id=last_message_id,
            started_at=active_session.created_at.isoformat(),
        )

    # Skip session metadata on "load more" — frontend only needs messages
@@ -816,17 +840,45 @@ async def cancel_session_task(
    return CancelSessionResponse(cancelled=True)


def _ui_message_stream_headers() -> dict[str, str]:
    return {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        "x-vercel-ai-ui-message-stream": "v1",
    }


def _empty_ui_message_stream_response() -> StreamingResponse:
    # Stable placeholder messageId for the empty queued-mid-turn stream.
    # Real turns generate per-message UUIDs via the executor; this stream
    # has no message to attach to, but the AI SDK parser still requires a
    # non-empty ``messageId`` field on ``StreamStart``.
    message_id = uuid4().hex

    async def event_generator() -> AsyncGenerator[str, None]:
        # Vercel AI SDK's UI-message-stream parser expects symmetric
        # start/finish framing at both stream and step level — every
        # non-empty turn emits the pair. Without an opener, today's parser
        # tolerates the closer (no active parts to flush) but a future SDK
        # tightening would silently break the queue-mid-turn UX. Emit the
        # full empty pair so the contract stays correct.
        yield StreamStart(messageId=message_id).to_sse()
        yield StreamStartStep().to_sse()
        yield StreamFinishStep().to_sse()
        yield StreamFinish().to_sse()
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=_ui_message_stream_headers(),
    )
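The symmetric framing `_empty_ui_message_stream_response` emits can be sketched end-to-end without the project's models. Here the `_sse` helper and the literal chunk `type` strings (`start`, `start-step`, `finish-step`, `finish`) are assumptions standing in for the `to_sse()` serializers in `backend.copilot.response_model`:

```python
import asyncio
import json
from typing import AsyncGenerator


def _sse(obj: dict) -> str:
    # Hypothetical stand-in for the response_model .to_sse() helpers: each
    # UI-message-stream chunk is a JSON payload inside an SSE `data:` frame.
    return f"data: {json.dumps(obj)}\n\n"


async def empty_ui_message_stream(message_id: str) -> AsyncGenerator[str, None]:
    # Stream-level and step-level open/close pairs, with no parts between.
    yield _sse({"type": "start", "messageId": message_id})
    yield _sse({"type": "start-step"})
    yield _sse({"type": "finish-step"})
    yield _sse({"type": "finish"})
    yield "data: [DONE]\n\n"


async def collect() -> list[str]:
    return [chunk async for chunk in empty_ui_message_stream("msg-1")]


frames = asyncio.run(collect())
print(len(frames))  # → 5
```

A parser that tracks open parts sees an opener for every closer, so an empty stream is indistinguishable from a zero-part turn rather than a protocol violation.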
@router.post(
    "/sessions/{session_id}/stream",
    responses={
        202: {
            "model": QueuePendingMessageResponse,
            "description": (
                "Session has a turn in flight — message queued into the pending "
                "buffer and will be picked up between tool-call rounds by the "
                "executor currently processing the turn."
            ),
        },
        404: {"description": "Session not found or access denied"},
        429: {"description": "Cost rate-limit or call-frequency cap exceeded"},
    },
@@ -836,19 +888,18 @@ async def stream_chat_post(
    request: StreamChatRequest,
    user_id: str = Security(auth.get_user_id),
):
    """Start a new turn OR queue a follow-up — decided server-side.
    """Start a new turn and return an AI SDK UI message stream.

    - **Session idle**: starts a turn. Returns an SSE stream (``text/event-stream``)
    with Vercel AI SDK chunks (text fragments, tool-call UI, tool results).
    The generation runs in a background task that survives client disconnects;
    reconnect via ``GET /sessions/{session_id}/stream`` to resume.
    Returns an SSE stream (``text/event-stream``) with Vercel AI SDK chunks
    (text fragments, tool-call UI, tool results). The generation runs in a
    background task that survives client disconnects; reconnect via
    ``GET /sessions/{session_id}/stream`` to resume.

    - **Session has a turn in flight**: pushes the message into the per-session
    pending buffer and returns ``202 application/json`` with
    ``QueuePendingMessageResponse``. The executor running the current turn
    drains the buffer between tool-call rounds (baseline) or at the start of
    the next turn (SDK). Clients should detect the 202 and surface the
    message as a queued-chip in the UI.
    Follow-up messages typed while a turn is already running should use
    ``POST /sessions/{session_id}/messages/pending``. If an older client still
    posts that follow-up here, we queue it defensively but still return a valid
    empty UI-message stream so AI SDK transports never receive a JSON body from
    the stream endpoint.

    Args:
        session_id: The chat session identifier.
@@ -872,26 +923,29 @@ async def stream_chat_post(
        extra={"json_fields": log_meta},
    )
    session = await _validate_and_get_session(session_id, user_id)
    builder_permissions = resolve_session_permissions(session)

    # Self-defensive queue-fallback: if a turn is already running, don't race
    # it on the cluster lock — drop the message into the pending buffer and
    # return 202 so the caller can render a chip. Both UI chips and autopilot
    # block follow-ups route through this path; keeping the decision on the
    # server means every caller gets uniform behaviour.
    if (
        request.is_user_message
        and request.message
        and await is_turn_in_flight(session_id)
    ):
        response = await queue_pending_for_http(
            session_id=session_id,
            user_id=user_id,
            message=request.message,
            context=request.context,
            file_ids=request.file_ids,
        )
        return JSONResponse(status_code=202, content=response.model_dump())
        try:
            await queue_pending_for_http(
                session_id=session_id,
                user_id=user_id,
                message=request.message,
                context=request.context,
                file_ids=request.file_ids,
            )
            return _empty_ui_message_stream_response()
        except HTTPException as exc:
            if exc.status_code != 409:
                raise

    # Permission resolution is only needed below for the actual turn — keep
    # it after the queue-fall-through so a queued mid-turn request returns
    # without paying the work.
    builder_permissions = resolve_session_permissions(session)

    logger.info(
        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
@@ -1130,12 +1184,37 @@ async def stream_chat_post(
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
        },
        headers=_ui_message_stream_headers(),
    )


@router.post(
    "/sessions/{session_id}/messages/pending",
    response_model=QueuePendingMessageResponse,
    responses={
        404: {"description": "Session not found or access denied"},
        409: {"description": "Session has no active turn to receive pending messages"},
        429: {"description": "Call-frequency cap exceeded"},
    },
)
async def queue_pending_message(
    session_id: str,
    request: QueuePendingMessageRequest,
    user_id: str = Security(auth.get_user_id),
):
    """Queue a follow-up message while the session has an active turn."""
    await _validate_and_get_session(session_id, user_id)
    if not await is_turn_in_flight(session_id):
        raise HTTPException(
            status_code=409,
            detail="Session has no active turn. Start a new turn with POST /stream.",
        )
    return await queue_pending_for_http(
        session_id=session_id,
        user_id=user_id,
        message=request.message,
        context=request.context,
        file_ids=request.file_ids,
    )


@@ -1169,6 +1248,7 @@ async def get_pending_messages(
)
async def resume_session_stream(
    session_id: str,
    last_chunk_id: str | None = Query(default=None, include_in_schema=False),
    user_id: str = Security(auth.get_user_id),
):
    """
@@ -1178,27 +1258,26 @@ async def resume_session_stream(
    Checks for an active (in-progress) task on the session and either replays
    the full SSE stream or returns 204 No Content if nothing is running.

    Args:
        session_id: The chat session identifier.
        user_id: Optional authenticated user ID.

    Returns:
        StreamingResponse (SSE) when an active stream exists,
        or 204 No Content when there is nothing to resume.
    Always replays the active turn from ``0-0``. The AI SDK UI-message parser
    keeps text/reasoning part state inside a single parser instance; resuming
    from a Redis cursor can skip the ``*-start`` events required by later
    ``*-delta`` chunks.
    """
    import asyncio

    active_session, last_message_id = await stream_registry.get_active_session(
    active_session, _latest_backend_id = await stream_registry.get_active_session(
        session_id, user_id
    )

    if not active_session:
        return Response(status_code=204)

    # Always replay from the beginning ("0-0") on resume.
    # We can't use last_message_id because it's the latest ID in the backend
    # stream, not the latest the frontend received — the gap causes lost
    # messages. The frontend deduplicates replayed content.
    if last_chunk_id:
        logger.info(
            "Ignoring deprecated last_chunk_id on stream resume",
            extra={"session_id": session_id, "last_chunk_id": last_chunk_id},
        )

    subscriber_queue = await stream_registry.subscribe_to_session(
        session_id=session_id,
        user_id=user_id,
@@ -1259,12 +1338,7 @@ async def resume_session_stream(
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "x-vercel-ai-ui-message-stream": "v1",
        },
        headers=_ui_message_stream_headers(),
    )


@@ -157,6 +157,11 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
        "backend.api.features.chat.routes._validate_and_get_session",
        return_value=None,
    )
    mocker.patch(
        "backend.api.features.chat.routes.is_turn_in_flight",
        new_callable=AsyncMock,
        return_value=False,
    )
    mock_save = mocker.patch(
        "backend.api.features.chat.routes.append_and_save_message",
        return_value=MagicMock(),  # non-None = message was saved (not a duplicate)
@@ -637,7 +642,7 @@ class TestStreamChatRequestModeValidation:
        assert req.mode is None


# ─── POST /stream queue-fallback (when a turn is already in flight) ──
# ─── Pending message queue (when a turn is already in flight) ─────────


def _mock_stream_queue_internals(
@@ -646,11 +651,9 @@ def _mock_stream_queue_internals(
    session_exists: bool = True,
    turn_in_flight: bool = True,
    call_count: int = 1,
    push_length: int | None = 1,
):
    """Mock dependencies for the POST /stream queue-fallback path.

    When ``turn_in_flight`` is True the handler takes the 202 queue branch.
    """
    """Mock dependencies for the pending-message queue path."""
    if session_exists:
        mock_session = mocker.MagicMock()
        mock_session.id = "sess-1"
@@ -692,12 +695,10 @@ def _mock_stream_queue_internals(
        return_value=call_count,
    )
    mocker.patch(
        "backend.copilot.pending_message_helpers.push_pending_message",
        "backend.copilot.pending_message_helpers.push_pending_message_if_session_running",
        new_callable=AsyncMock,
        return_value=1,
        return_value=push_length,
    )
    # queue_user_message re-runs is_turn_in_flight via the helper module —
    # stub that path out too so we don't need a fake stream_registry.
    mocker.patch(
        "backend.copilot.pending_message_helpers.get_active_session_meta",
        new_callable=AsyncMock,
@@ -705,37 +706,65 @@ def _mock_stream_queue_internals(
    )


def test_stream_queue_returns_202_when_turn_in_flight(
def test_queue_pending_message_returns_200_when_turn_in_flight(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Happy path: POST /stream to a session with a live turn → 202 queue."""
    """Happy path: POST /messages/pending to a live turn queues the message."""
    _mock_stream_queue_internals(mocker)

    response = client.post(
        "/sessions/sess-1/stream",
        json={"message": "follow-up", "is_user_message": True},
        "/sessions/sess-1/messages/pending",
        json={"message": "follow-up"},
    )

    assert response.status_code == 202
    assert response.status_code == 200
    data = response.json()
    assert data["buffer_length"] == 1
    assert "turn_in_flight" in data


def test_stream_queue_session_not_found_returns_404(
def test_queue_pending_message_session_not_found_returns_404(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """If the session doesn't exist or belong to the user, returns 404."""
    _mock_stream_queue_internals(mocker, session_exists=False)

    response = client.post(
        "/sessions/bad-sess/stream",
        json={"message": "hi", "is_user_message": True},
        "/sessions/bad-sess/messages/pending",
        json={"message": "hi"},
    )
    assert response.status_code == 404


def test_stream_queue_call_frequency_limit_returns_429(
def test_queue_pending_message_without_active_turn_returns_409(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A pending-message push needs an active turn to consume it."""
    _mock_stream_queue_internals(mocker, turn_in_flight=False)

    response = client.post(
        "/sessions/sess-1/messages/pending",
        json={"message": "hi"},
    )

    assert response.status_code == 409


def test_queue_pending_message_race_after_active_check_returns_409(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """If the active turn ends before the atomic push, the message is not queued."""
    _mock_stream_queue_internals(mocker, push_length=None)

    response = client.post(
        "/sessions/sess-1/messages/pending",
        json={"message": "hi"},
    )

    assert response.status_code == 409


def test_queue_pending_message_call_frequency_limit_returns_429(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Per-user call-frequency cap rejects rapid-fire queued pushes."""
@@ -744,14 +773,14 @@ def test_stream_queue_call_frequency_limit_returns_429(
    _mock_stream_queue_internals(mocker, call_count=PENDING_CALL_LIMIT + 1)

    response = client.post(
        "/sessions/sess-1/stream",
        json={"message": "hi", "is_user_message": True},
|
||||
"/sessions/sess-1/messages/pending",
|
||||
json={"message": "hi"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
assert "Too many queued message requests this minute" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_stream_queue_converts_context_dict_to_pending_context(
|
||||
def test_queue_pending_message_converts_context_dict_to_pending_context(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""StreamChatRequest.context is a raw dict; must be coerced to the
|
||||
@@ -768,15 +797,14 @@ def test_stream_queue_converts_context_dict_to_pending_context(
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
"/sessions/sess-1/messages/pending",
|
||||
json={
|
||||
"message": "hi",
|
||||
"is_user_message": True,
|
||||
"context": {"url": "https://example.test", "content": "body"},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.status_code == 200
|
||||
queue_spy.assert_awaited_once()
|
||||
kwargs = queue_spy.await_args.kwargs
|
||||
from backend.copilot.pending_messages import PendingMessageContext
|
||||
@@ -786,7 +814,7 @@ def test_stream_queue_converts_context_dict_to_pending_context(
|
||||
assert kwargs["context"].content == "body"
|
||||
|
||||
|
||||
def test_stream_queue_passes_none_context_when_omitted(
|
||||
def test_queue_pending_message_passes_none_context_when_omitted(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""When request.context is omitted, the queue call receives context=None."""
|
||||
@@ -802,15 +830,31 @@ def test_stream_queue_passes_none_context_when_omitted(
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "is_user_message": True},
|
||||
"/sessions/sess-1/messages/pending",
|
||||
json={"message": "hi"},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.status_code == 200
|
||||
queue_spy.assert_awaited_once()
|
||||
assert queue_spy.await_args.kwargs["context"] is None
|
||||
|
||||
|
||||
def test_stream_chat_queues_legacy_inflight_post_but_returns_sse(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /stream must not return JSON to an AI SDK transport."""
|
||||
_mock_stream_queue_internals(mocker)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "follow-up", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
assert '"type":"finish"' in response.text
|
||||
|
||||
|
||||
# ─── get_pending_messages (GET /sessions/{session_id}/messages/pending) ─────
|
||||
|
||||
|
||||
@@ -1581,9 +1625,14 @@ def test_resume_session_stream_no_subscriber_queue(
|
||||
mock_registry.subscribe_to_session = AsyncMock(return_value=None)
|
||||
mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry)
|
||||
|
||||
response = client.get("/sessions/sess-1/stream")
|
||||
response = client.get("/sessions/sess-1/stream?last_chunk_id=9999-9")
|
||||
|
||||
assert response.status_code == 204
|
||||
mock_registry.subscribe_to_session.assert_awaited_once_with(
|
||||
session_id="sess-1",
|
||||
user_id=TEST_USER_ID,
|
||||
last_message_id="0-0",
|
||||
)
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
@@ -7,6 +7,7 @@ allowing frontend code generators like Orval to create corresponding TypeScript

from pydantic import BaseModel, Field

from backend.data.model import CredentialsType
from backend.integrations.providers import ProviderName
from backend.sdk.registry import AutoRegistry

@@ -47,6 +48,57 @@ class ProviderNamesResponse(BaseModel):
    )


class ProviderMetadata(BaseModel):
    """Display metadata for a provider, shown in the settings integrations UI."""

    name: str = Field(description="Provider slug (e.g. ``github``)")
    description: str | None = Field(
        default=None,
        description=(
            "One-line human-readable summary of what the provider does. "
            "Declared via ``ProviderBuilder.with_description(...)`` in the "
            "provider's ``_config.py``. ``None`` if not set."
        ),
    )
    supported_auth_types: list[CredentialsType] = Field(
        default_factory=list,
        description=(
            "Credential types this provider accepts. Drives which connection "
            "tabs the settings UI renders for the provider. Empty list means "
            "no auth types declared."
        ),
    )


def get_supported_auth_types(name: str) -> list[CredentialsType]:
    """Return the provider's supported credential types from :class:`AutoRegistry`.

    Populated by :meth:`ProviderBuilder.with_supported_auth_types` (or by
    ``with_oauth`` / ``with_api_key`` / ``with_user_password`` when the provider
    uses the full builder chain). Returns an empty list for providers with no
    auth types declared.
    """
    provider = AutoRegistry.get_provider(name)
    if provider is None:
        return []
    return sorted(provider.supported_auth_types)


def get_provider_description(name: str) -> str | None:
    """Return the provider's description from :class:`AutoRegistry`.

    Descriptions are declared via ``ProviderBuilder.with_description(...)`` in
    the provider's ``_config.py`` (SDK path) or in
    ``blocks/_static_provider_configs.py`` (for providers that don't yet have
    their own directory). Returns ``None`` for providers with no registered
    description.
    """
    provider = AutoRegistry.get_provider(name)
    if provider is None:
        return None
    return provider.description


class ProviderConstants(BaseModel):
    """
    Model that exposes all provider names as a constant in the OpenAPI schema.
@@ -66,7 +66,14 @@ from backend.util.exceptions import (
)
from backend.util.settings import Settings

from .models import ProviderConstants, ProviderNamesResponse, get_all_provider_names
from .models import (
    ProviderConstants,
    ProviderMetadata,
    ProviderNamesResponse,
    get_all_provider_names,
    get_provider_description,
    get_supported_auth_types,
)

if TYPE_CHECKING:
    from backend.integrations.oauth import BaseOAuthHandler
@@ -1204,20 +1211,37 @@ async def get_ayrshare_sso_url(


# === PROVIDER DISCOVERY ENDPOINTS ===
@router.get("/providers", response_model=List[str])
async def list_providers() -> List[str]:
@router.get("/providers", response_model=List[ProviderMetadata])
async def list_providers() -> List[ProviderMetadata]:
    """
    Get a list of all available provider names.
    Get metadata for every available provider.

    Returns both statically defined providers (from ProviderName enum)
    and dynamically registered providers (from SDK decorators).
    Returns both statically defined providers (from ``ProviderName`` enum) and
    dynamically registered providers (from SDK decorators). Each entry includes
    a ``description`` declared via ``ProviderBuilder.with_description(...)`` in
    the provider's ``_config.py``.

    Note: The complete list of provider names is also available as a constant
    in the generated TypeScript client via PROVIDER_NAMES.
    """
    # Get all providers at runtime
    # Ensure all block modules (and therefore every provider's _config.py) are
    # imported before we read from AutoRegistry. Cached on first call.
    try:
        from backend.blocks import load_all_blocks

        load_all_blocks()
    except Exception as e:
        logger.warning(f"Failed to load blocks for provider metadata: {e}")

    all_providers = get_all_provider_names()
    return all_providers
    return [
        ProviderMetadata(
            name=name,
            description=get_provider_description(name),
            supported_auth_types=get_supported_auth_types(name),
        )
        for name in all_providers
    ]


@router.get("/providers/system", response_model=List[str])
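The endpoint's response shape changes from a bare list of slugs to a list of metadata objects. A sketch of what a generated client would now see — the field values here are illustrative, not real registry data:

```python
# Before: GET /providers returned plain slugs.
old_response = ["github", "notion"]

# After: each entry carries the slug plus display metadata.
new_response = [
    {
        "name": "github",
        "description": "Code hosting",
        "supported_auth_types": ["api_key", "oauth2"],
    },
    {
        "name": "notion",
        "description": None,  # provider declared no description
        "supported_auth_types": [],  # no auth types → no connection tabs
    },
]

# Every entry still carries the slug, so callers that only need names
# can recover the old shape with a map.
names = [entry["name"] for entry in new_response]
```

This is a breaking response-schema change for consumers typed against `List[str]`, which is presumably why the Orval-generated TypeScript client is regenerated alongside it.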
autogpt_platform/backend/backend/api/features/push/model.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import pydantic


class PushSubscriptionKeys(pydantic.BaseModel):
    p256dh: str = pydantic.Field(min_length=1, max_length=512)
    auth: str = pydantic.Field(min_length=1, max_length=512)


class PushSubscribeRequest(pydantic.BaseModel):
    endpoint: str = pydantic.Field(min_length=1, max_length=2048)
    keys: PushSubscriptionKeys
    user_agent: str | None = pydantic.Field(default=None, max_length=512)


class PushUnsubscribeRequest(pydantic.BaseModel):
    endpoint: str = pydantic.Field(min_length=1, max_length=2048)


class VapidPublicKeyResponse(pydantic.BaseModel):
    public_key: str
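The request models enforce hard length bounds declaratively, so malformed payloads fail with HTTP 422 before the handler runs. A stdlib sketch of the same bounds checks (the pydantic `Field(min_length=..., max_length=...)` constraints do this for you; this helper name is hypothetical):

```python
def validate_subscription_payload(endpoint: str, p256dh: str, auth: str) -> bool:
    """Mirror PushSubscribeRequest's Field constraints: non-empty, bounded length."""
    return (
        1 <= len(endpoint) <= 2048   # endpoint: min_length=1, max_length=2048
        and 1 <= len(p256dh) <= 512  # keys.p256dh: min_length=1, max_length=512
        and 1 <= len(auth) <= 512    # keys.auth: min_length=1, max_length=512
    )
```

These are the exact bounds the test file below exercises: empty crypto keys and a ~3000-character endpoint both fail validation.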
autogpt_platform/backend/backend/api/features/push/routes.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from typing import Annotated

from autogpt_libs.auth import get_user_id, requires_user
from fastapi import APIRouter, HTTPException, Security
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST

from backend.api.features.push.model import (
    PushSubscribeRequest,
    PushUnsubscribeRequest,
    VapidPublicKeyResponse,
)
from backend.data.push_subscription import (
    delete_push_subscription,
    upsert_push_subscription,
    validate_push_endpoint,
)
from backend.util.settings import Settings

router = APIRouter()
_settings = Settings()


@router.get(
    "/vapid-key",
    summary="Get VAPID public key for push subscription",
)
async def get_vapid_public_key() -> VapidPublicKeyResponse:
    return VapidPublicKeyResponse(public_key=_settings.secrets.vapid_public_key)


@router.post(
    "/subscribe",
    summary="Register a push subscription for the current user",
    status_code=HTTP_204_NO_CONTENT,
    dependencies=[Security(requires_user)],
)
async def subscribe_push(
    user_id: Annotated[str, Security(get_user_id)],
    body: PushSubscribeRequest,
) -> None:
    try:
        await validate_push_endpoint(body.endpoint)
        await upsert_push_subscription(
            user_id=user_id,
            endpoint=body.endpoint,
            p256dh=body.keys.p256dh,
            auth=body.keys.auth,
            user_agent=body.user_agent,
        )
    except ValueError as e:
        raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=str(e))


@router.post(
    "/unsubscribe",
    summary="Remove a push subscription",
    status_code=HTTP_204_NO_CONTENT,
    dependencies=[Security(requires_user)],
)
async def unsubscribe_push(
    user_id: Annotated[str, Security(get_user_id)],
    body: PushUnsubscribeRequest,
) -> None:
    await delete_push_subscription(user_id, body.endpoint)
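`subscribe_push` calls `validate_push_endpoint` before persisting anything, and the test file below pins its behavior: HTTPS only, only well-known push-service hosts, no localhost, link-local, or arbitrary domains. A hedged sketch of such a check — the real allow-list lives in `backend.data.push_subscription` and may differ from the hosts assumed here:

```python
from urllib.parse import urlparse

# Hypothetical allow-list of browser push-service hosts.
TRUSTED_PUSH_HOSTS = (
    "fcm.googleapis.com",
    "updates.push.services.mozilla.com",
    "notify.windows.com",
)


def is_trusted_push_endpoint(endpoint: str) -> bool:
    parsed = urlparse(endpoint)
    if parsed.scheme != "https":
        return False  # rejects http:// and file:// endpoints outright
    host = parsed.hostname or ""
    # Exact match or subdomain of a trusted push service; anything else
    # (localhost, 169.254.169.254, attacker domains) is refused, which
    # closes the SSRF vector of the server POSTing to arbitrary URLs.
    return any(host == h or host.endswith("." + h) for h in TRUSTED_PUSH_HOSTS)
```

Rejecting untrusted endpoints at subscribe time matters because the server later delivers pushes to the stored URL; without the check, a stored `https://169.254.169.254/...` endpoint would turn push delivery into an SSRF primitive.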
@@ -0,0 +1,240 @@
"""Tests for push notification routes."""

from unittest.mock import AsyncMock, MagicMock

import fastapi
import fastapi.testclient
import pytest

from backend.api.features.push.routes import router

app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_get_vapid_public_key(mocker):
    mock_settings = MagicMock()
    mock_settings.secrets.vapid_public_key = "test-vapid-public-key-base64url"
    mocker.patch(
        "backend.api.features.push.routes._settings",
        mock_settings,
    )

    response = client.get("/vapid-key")

    assert response.status_code == 200
    data = response.json()
    assert data["public_key"] == "test-vapid-public-key-base64url"


def test_get_vapid_public_key_empty(mocker):
    mock_settings = MagicMock()
    mock_settings.secrets.vapid_public_key = ""
    mocker.patch(
        "backend.api.features.push.routes._settings",
        mock_settings,
    )

    response = client.get("/vapid-key")

    assert response.status_code == 200
    data = response.json()
    assert data["public_key"] == ""


def test_subscribe_push(mocker, test_user_id):
    mock_upsert = mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
            "user_agent": "Mozilla/5.0 Test",
        },
    )

    assert response.status_code == 204
    mock_upsert.assert_awaited_once_with(
        user_id=test_user_id,
        endpoint="https://fcm.googleapis.com/fcm/send/abc123",
        p256dh="test-p256dh-key",
        auth="test-auth-key",
        user_agent="Mozilla/5.0 Test",
    )


def test_subscribe_push_without_user_agent(mocker, test_user_id):
    mock_upsert = mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 204
    mock_upsert.assert_awaited_once_with(
        user_id=test_user_id,
        endpoint="https://fcm.googleapis.com/fcm/send/abc123",
        p256dh="test-p256dh-key",
        auth="test-auth-key",
        user_agent=None,
    )


def test_subscribe_push_missing_keys():
    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
        },
    )

    assert response.status_code == 422


def test_subscribe_push_missing_endpoint():
    response = client.post(
        "/subscribe",
        json={
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 422


def test_subscribe_push_rejects_empty_crypto_keys():
    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {"p256dh": "", "auth": ""},
        },
    )

    assert response.status_code == 422


def test_subscribe_push_rejects_oversized_endpoint():
    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/" + "x" * 3000,
            "keys": {"p256dh": "k", "auth": "a"},
        },
    )

    assert response.status_code == 422


def test_unsubscribe_push(mocker, test_user_id):
    mock_delete = mocker.patch(
        "backend.api.features.push.routes.delete_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/unsubscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
        },
    )

    assert response.status_code == 204
    mock_delete.assert_awaited_once_with(
        test_user_id,
        "https://fcm.googleapis.com/fcm/send/abc123",
    )


def test_unsubscribe_push_missing_endpoint():
    response = client.post(
        "/unsubscribe",
        json={},
    )

    assert response.status_code == 422


@pytest.mark.parametrize(
    "untrusted_endpoint",
    [
        "https://localhost/evil",
        "https://127.0.0.1/evil",
        "https://169.254.169.254/latest/meta-data/",
        "https://internal-service.local/api",
        "https://attacker.example.com/push",
        "http://fcm.googleapis.com/fcm/send/abc",
        "file:///etc/passwd",
    ],
)
def test_subscribe_push_rejects_untrusted_endpoints(mocker, untrusted_endpoint):
    mock_upsert = mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": untrusted_endpoint,
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 400
    mock_upsert.assert_not_awaited()


def test_subscribe_push_surfaces_cap_as_400(mocker):
    mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
        side_effect=ValueError("Subscription limit of 20 per user reached"),
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 400
    assert "Subscription limit" in response.json()["detail"]
@@ -490,6 +490,9 @@ async def get_store_creators(
    # Build where clause with sanitized inputs
    where = {}

    # Only return creators with approved agents
    where["num_agents"] = {"gt": 0}

    if featured:
        where["is_featured"] = featured
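The hunk adds an unconditional `num_agents` filter ahead of the optional `featured` one, so creators without an approved agent never appear regardless of the other query parameters. The resulting where-clause construction can be sketched as plain dict building (function name hypothetical):

```python
def build_creator_where(featured: bool = False) -> dict:
    """Sketch of the Prisma where clause assembled by get_store_creators."""
    where: dict = {}
    # Only return creators with at least one approved agent.
    where["num_agents"] = {"gt": 0}
    if featured:
        where["is_featured"] = featured
    return where
```

The same dict is passed to both `find_many` and `count`, which is exactly what the new `test_get_store_creators_only_returns_approved` test below inspects via `call_args`.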
@@ -1,4 +1,5 @@
from datetime import datetime
from unittest.mock import AsyncMock

import prisma.enums
import prisma.errors
@@ -50,8 +51,8 @@ async def test_get_store_agents(mocker):

    # Mock prisma calls
    mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
    mock_store_agent.return_value.find_many = mocker.AsyncMock(return_value=mock_agents)
    mock_store_agent.return_value.count = mocker.AsyncMock(return_value=1)
    mock_store_agent.return_value.find_many = AsyncMock(return_value=mock_agents)
    mock_store_agent.return_value.count = AsyncMock(return_value=1)

    # Call function
    result = await db.get_store_agents()
@@ -94,7 +95,7 @@ async def test_get_store_agent_details(mocker):

    # Mock StoreAgent prisma call
    mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
    mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
    mock_store_agent.return_value.find_first = AsyncMock(return_value=mock_agent)

    # Call function
    result = await db.get_store_agent_details("creator", "test-agent")
@@ -133,7 +134,7 @@ async def test_get_store_creator(mocker):

    # Mock prisma call
    mock_creator = mocker.patch("prisma.models.Creator.prisma")
    mock_creator.return_value.find_unique = mocker.AsyncMock()
    mock_creator.return_value.find_unique = AsyncMock()
    # Configure the mock to return values that will pass validation
    mock_creator.return_value.find_unique.return_value = mock_creator_data

@@ -236,23 +237,23 @@ async def test_create_store_submission(mocker):

    # Mock prisma calls
    mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
    mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
    mock_agent_graph.return_value.find_first = AsyncMock(return_value=mock_agent)

    # Mock transaction context manager
    mock_tx = mocker.MagicMock()
    mocker.patch(
        "backend.api.features.store.db.transaction",
        return_value=mocker.AsyncMock(
            __aenter__=mocker.AsyncMock(return_value=mock_tx),
            __aexit__=mocker.AsyncMock(return_value=False),
        return_value=AsyncMock(
            __aenter__=AsyncMock(return_value=mock_tx),
            __aexit__=AsyncMock(return_value=False),
        ),
    )

    mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
    mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
    mock_sl.return_value.find_unique = AsyncMock(return_value=None)

    mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
    mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
    mock_slv.return_value.create = AsyncMock(return_value=mock_version)

    # Call function
    result = await db.create_store_submission(
@@ -292,10 +293,8 @@ async def test_update_profile(mocker):

    # Mock prisma calls
    mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
    mock_profile_db.return_value.find_first = mocker.AsyncMock(
        return_value=mock_profile
    )
    mock_profile_db.return_value.update = mocker.AsyncMock(return_value=mock_profile)
    mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
    mock_profile_db.return_value.update = AsyncMock(return_value=mock_profile)

    # Test data
    profile = Profile(
@@ -336,9 +335,7 @@ async def test_get_user_profile(mocker):

    # Mock prisma calls
    mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
    mock_profile_db.return_value.find_first = mocker.AsyncMock(
        return_value=mock_profile
    )
    mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)

    # Call function
    result = await db.get_user_profile("user-id")
@@ -396,3 +393,38 @@ async def test_get_store_agents_search_category_array_injection():
    # Verify the query executed without error
    # Category should be parameterized, preventing SQL injection
    assert isinstance(result.agents, list)


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creators_only_returns_approved(mocker):
    mock_creators = [
        prisma.models.Creator(
            name="Creator One",
            username="creator1",
            description="desc",
            links=["link1"],
            avatar_url="avatar.jpg",
            num_agents=1,
            agent_rating=4.5,
            agent_runs=10,
            top_categories=["test"],
            is_featured=False,
        )
    ]

    mock_creator = mocker.patch("prisma.models.Creator.prisma")
    mock_creator.return_value.find_many = AsyncMock(return_value=mock_creators)
    mock_creator.return_value.count = AsyncMock(return_value=1)

    result = await db.get_store_creators()

    assert len(result.creators) == 1
    assert result.creators[0].username == "creator1"

    mock_creator.return_value.find_many.assert_called_once()
    mock_creator.return_value.count.assert_called_once()

    _, find_kwargs = mock_creator.return_value.find_many.call_args
    _, count_kwargs = mock_creator.return_value.count.call_args
    assert find_kwargs["where"]["num_agents"] == {"gt": 0}
    assert count_kwargs["where"]["num_agents"] == {"gt": 0}
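The mechanical change in the hunks above swaps `mocker.AsyncMock` for `unittest.mock.AsyncMock` imported directly; the awaited-call assertion API is identical either way. A minimal self-contained demonstration of the pattern the tests rely on:

```python
import asyncio
from unittest.mock import AsyncMock

# Stand-in for a patched Prisma query method.
mock_find_many = AsyncMock(return_value=["creator-row"])


async def query():
    # Awaiting an AsyncMock records the call and returns its return_value.
    return await mock_find_many(where={"num_agents": {"gt": 0}})


result = asyncio.run(query())
assert result == ["creator-row"]
# Same assertion style the store tests use to verify the where clause.
mock_find_many.assert_awaited_once_with(where={"num_agents": {"gt": 0}})
```

`assert_awaited_once_with` (rather than `assert_called_once_with`) additionally verifies the coroutine was actually awaited, not merely created.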
@@ -70,7 +70,8 @@ _DEFAULT_TIER_PRICES: dict[SubscriptionTier, str | None] = {

@pytest.fixture(autouse=True)
def _stub_subscription_status_lookups(mocker: pytest_mock.MockFixture) -> None:
    """Stub Stripe price + proration lookups used by get_subscription_status.
    """Stub Stripe price + proration + tier-multiplier lookups used by
    get_subscription_status.

    The POST /credits/subscription handler now returns the full subscription
    status payload from every branch (same-tier, BASIC downgrade, paid→paid
@@ -90,6 +91,16 @@ def _stub_subscription_status_lookups(mocker: pytest_mock.MockFixture) -> None:
        new_callable=AsyncMock,
        return_value=0,
    )
    # Default tier-multiplier resolver to the backend defaults so the endpoint
    # never reaches LaunchDarkly during tests. Individual tests override for
    # LD-override scenarios.
    from backend.copilot.rate_limit import _DEFAULT_TIER_MULTIPLIERS

    mocker.patch(
        "backend.api.features.v1.get_tier_multipliers",
        new_callable=AsyncMock,
        return_value=dict(_DEFAULT_TIER_MULTIPLIERS),
    )


@pytest.mark.parametrize(
@@ -187,13 +198,59 @@ def test_get_subscription_status_pro(
    assert data["tier_costs"]["BASIC"] == 0
    assert "ENTERPRISE" not in data["tier_costs"]
    assert data["proration_credit_cents"] == 500
    # tier_multipliers mirrors the same set of tiers that land in tier_costs,
    # so the frontend never renders a multiplier badge for a hidden row.
    assert set(data["tier_multipliers"].keys()) == set(data["tier_costs"].keys())
    assert data["tier_multipliers"]["BASIC"] == 1.0
    assert data["tier_multipliers"]["PRO"] == 5.0
    assert data["tier_multipliers"]["MAX"] == 20.0
    assert data["tier_multipliers"]["BUSINESS"] == 60.0


def test_get_subscription_status_defaults_to_basic(
def test_get_subscription_status_tier_multipliers_ld_override(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """When all LD price IDs are unset, tier_costs is empty and the caller sees cost=0."""
    """A LaunchDarkly-overridden tier multiplier flows through the response."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.BASIC

    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,
        return_value=mock_user,
    )

    # LD says PRO is 7.5× (instead of the 5× default); other tiers unchanged.
    mocker.patch(
        "backend.api.features.v1.get_tier_multipliers",
        new_callable=AsyncMock,
        return_value={
            SubscriptionTier.BASIC: 1.0,
            SubscriptionTier.PRO: 7.5,
            SubscriptionTier.MAX: 20.0,
            SubscriptionTier.BUSINESS: 60.0,
            SubscriptionTier.ENTERPRISE: 60.0,
        },
    )

    response = client.get("/credits/subscription")
    assert response.status_code == 200
    data = response.json()
    # Only tiers that made it into tier_costs get a multiplier (default stub
    # exposes PRO + MAX via _DEFAULT_TIER_PRICES).
    assert data["tier_multipliers"]["PRO"] == 7.5
    assert data["tier_multipliers"]["MAX"] == 20.0
    # BUSINESS has no price configured → hidden from both maps.
    assert "BUSINESS" not in data["tier_multipliers"]


def test_get_subscription_status_defaults_to_no_tier(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """When user has no subscription_tier, defaults to NO_TIER (the explicit
    no-active-subscription state)."""
    mock_user = Mock()
    mock_user.subscription_tier = None

@@ -217,7 +274,7 @@ def test_get_subscription_status_defaults_to_basic(

    assert response.status_code == 200
    data = response.json()
    assert data["tier"] == SubscriptionTier.BASIC.value
    assert data["tier"] == SubscriptionTier.NO_TIER.value
    assert data["monthly_cost"] == 0
    assert data["tier_costs"] == {}
    assert data["proration_credit_cents"] == 0
@@ -270,11 +327,11 @@ def test_get_subscription_status_stripe_error_falls_back_to_zero(
    assert data["tier_costs"]["PRO"] == 0


def test_update_subscription_tier_basic_no_payment(
def test_update_subscription_tier_no_tier_no_payment(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """POST /credits/subscription to BASIC tier when payment disabled skips Stripe."""
    """POST /credits/subscription to NO_TIER (cancel) when payment disabled skips Stripe."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

@@ -295,7 +352,7 @@ def test_update_subscription_tier_basic_no_payment(
        new_callable=AsyncMock,
    )

    response = client.post("/credits/subscription", json={"tier": "BASIC"})
    response = client.post("/credits/subscription", json={"tier": "NO_TIER"})

    assert response.status_code == 200
    assert response.json()["url"] == ""
@@ -348,12 +405,109 @@ def test_update_subscription_tier_paid_requires_urls(
        "backend.api.features.v1.is_feature_enabled",
        side_effect=mock_feature_enabled,
    )
    mocker.patch(
        "backend.api.features.v1.modify_stripe_subscription_for_tier",
        new_callable=AsyncMock,
        return_value=False,
    )

    response = client.post("/credits/subscription", json={"tier": "PRO"})

    assert response.status_code == 422


def test_update_subscription_tier_currency_mismatch_returns_422(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """Stripe rejects a SubscriptionSchedule whose phases mix currencies (e.g.
    GBP-checkout sub trying to schedule a USD-only target Price). The handler
    must convert that into a specific 422 instead of the generic 502 so the
    caller can tell the difference between a currency-config bug and a Stripe
    outage."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.MAX

    async def mock_feature_enabled(*args, **kwargs):
        return True

    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,
        return_value=mock_user,
    )
    mocker.patch(
        "backend.api.features.v1.is_feature_enabled",
        side_effect=mock_feature_enabled,
    )
    mocker.patch(
        "backend.api.features.v1.modify_stripe_subscription_for_tier",
        side_effect=stripe.InvalidRequestError(
            "The price specified only supports `usd`. This doesn't match the"
            " expected currency: `gbp`.",
            param="phases",
        ),
    )

    response = client.post(
        "/credits/subscription",
        json={
            "tier": "PRO",
            "success_url": f"{TEST_FRONTEND_ORIGIN}/success",
            "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
        },
    )

    assert response.status_code == 422
    detail = response.json()["detail"]
    assert "billing currency" in detail.lower()
    assert "contact support" in detail.lower()


def test_update_subscription_tier_non_currency_invalid_request_returns_502(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """Locks the contract that *only* currency-mismatch InvalidRequestErrors
    translate to 422 — every other Stripe InvalidRequestError must still
|
||||
surface as the generic 502 so that widening the conditional later is
|
||||
caught by the suite."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.MAX
|
||||
|
||||
async def mock_feature_enabled(*args, **kwargs):
|
||||
return True
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.modify_stripe_subscription_for_tier",
|
||||
side_effect=stripe.InvalidRequestError(
|
||||
"No such price: 'price_does_not_exist'",
|
||||
param="items[0][price]",
|
||||
),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/subscription",
|
||||
json={
|
||||
"tier": "PRO",
|
||||
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "billing currency" not in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_update_subscription_tier_creates_checkout(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
@@ -374,6 +528,11 @@ def test_update_subscription_tier_creates_checkout(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.modify_stripe_subscription_for_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.create_subscription_checkout",
|
||||
new_callable=AsyncMock,
|
||||
@@ -413,6 +572,11 @@ def test_update_subscription_tier_rejects_open_redirect(
|
||||
"backend.api.features.v1.is_feature_enabled",
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.modify_stripe_subscription_for_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
)
|
||||
checkout_mock = mocker.patch(
|
||||
"backend.api.features.v1.create_subscription_checkout",
|
||||
new_callable=AsyncMock,
|
||||
@@ -593,14 +757,14 @@ def test_update_subscription_tier_same_tier_stripe_error_returns_502(
|
||||
assert "contact support" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_update_subscription_tier_basic_with_payment_schedules_cancel_and_does_not_update_db(
|
||||
def test_update_subscription_tier_no_tier_with_payment_schedules_cancel_and_does_not_update_db(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Downgrading to BASIC schedules Stripe cancellation at period end.
|
||||
"""Cancelling to NO_TIER schedules Stripe cancellation at period end.
|
||||
|
||||
The DB tier must NOT be updated immediately — the customer.subscription.deleted
|
||||
webhook fires at period end and downgrades to BASIC then.
|
||||
webhook fires at period end and downgrades to NO_TIER then.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
@@ -626,18 +790,18 @@ def test_update_subscription_tier_basic_with_payment_schedules_cancel_and_does_n
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "BASIC"})
|
||||
response = client.post("/credits/subscription", json={"tier": "NO_TIER"})
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_cancel.assert_awaited_once()
|
||||
mock_set_tier.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_subscription_tier_basic_cancel_failure_returns_502(
|
||||
def test_update_subscription_tier_no_tier_cancel_failure_returns_502(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Downgrading to BASIC returns 502 with a generic error (no Stripe detail leakage)."""
|
||||
"""Cancelling to NO_TIER returns 502 with a generic error (no Stripe detail leakage)."""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
@@ -660,7 +824,7 @@ def test_update_subscription_tier_basic_cancel_failure_returns_502(
|
||||
side_effect=mock_feature_enabled,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "BASIC"})
|
||||
response = client.post("/credits/subscription", json={"tier": "NO_TIER"})
|
||||
|
||||
assert response.status_code == 502
|
||||
detail = response.json()["detail"]
|
||||
@@ -865,29 +1029,20 @@ def test_update_subscription_tier_max_checkout(
|
||||
checkout_mock.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly(
|
||||
def test_update_subscription_tier_no_active_sub_falls_through_to_checkout(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Admin-granted paid tier users are NOT sent to Stripe checkout for paid→paid changes.
|
||||
"""Any tier change from a user with no active Stripe sub goes through Checkout.
|
||||
|
||||
When modify_stripe_subscription_for_tier returns False (no Stripe subscription
|
||||
found — admin-granted tier), the endpoint must update the DB tier directly and
|
||||
return 200 with url="", rather than falling through to Checkout Session creation.
|
||||
Admin-granted users (no Stripe sub yet) and never-paid users follow the
|
||||
exact same path: modify returns False → Checkout to set up payment. The
|
||||
endpoint has no admin-specific branch — admin tier grants happen out-of-band
|
||||
via the admin portal, not this user-facing route.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.subscription_tier = SubscriptionTier.PRO
|
||||
|
||||
async def price_id_with_business(tier: SubscriptionTier) -> str | None:
|
||||
return {
|
||||
**_DEFAULT_TIER_PRICES,
|
||||
SubscriptionTier.BUSINESS: "price_business",
|
||||
}.get(tier)
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_subscription_price_id",
|
||||
side_effect=price_id_with_business,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_user_by_id",
|
||||
new_callable=AsyncMock,
|
||||
@@ -898,7 +1053,6 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
)
|
||||
# Return False = no Stripe subscription (admin-granted tier)
|
||||
modify_mock = mocker.patch(
|
||||
"backend.api.features.v1.modify_stripe_subscription_for_tier",
|
||||
new_callable=AsyncMock,
|
||||
@@ -911,23 +1065,24 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
|
||||
checkout_mock = mocker.patch(
|
||||
"backend.api.features.v1.create_subscription_checkout",
|
||||
new_callable=AsyncMock,
|
||||
return_value="https://checkout.stripe.com/pay/cs_test_no_sub",
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/credits/subscription",
|
||||
json={
|
||||
"tier": "BUSINESS",
|
||||
"tier": "MAX",
|
||||
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == ""
|
||||
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS)
|
||||
# DB tier updated directly — no Stripe Checkout Session created
|
||||
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS)
|
||||
checkout_mock.assert_not_awaited()
|
||||
assert response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_no_sub"
|
||||
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.MAX)
|
||||
# No DB-flip — payment must be collected via Checkout regardless of direction.
|
||||
set_tier_mock.assert_not_awaited()
|
||||
checkout_mock.assert_awaited_once()
|
||||
|
||||
|
||||
def test_update_subscription_tier_priced_basic_no_sub_falls_through_to_checkout(
|
||||
@@ -1098,14 +1253,14 @@ def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
|
||||
assert response.status_code == 502
|
||||
|
||||
|
||||
def test_update_subscription_tier_basic_no_stripe_subscription(
|
||||
def test_update_subscription_tier_no_tier_no_stripe_subscription(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Downgrading to BASIC when no Stripe subscription exists updates DB tier directly.
|
||||
"""Cancelling to NO_TIER when no Stripe subscription exists updates DB tier directly.
|
||||
|
||||
Admin-granted paid tiers have no associated Stripe subscription. When such a
|
||||
user requests a self-service downgrade, cancel_stripe_subscription returns False
|
||||
user requests a self-service cancel, cancel_stripe_subscription returns False
|
||||
(nothing to cancel), so the endpoint must immediately call set_subscription_tier
|
||||
rather than waiting for a webhook that will never arrive.
|
||||
"""
|
||||
@@ -1133,13 +1288,13 @@ def test_update_subscription_tier_basic_no_stripe_subscription(
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "BASIC"})
|
||||
response = client.post("/credits/subscription", json={"tier": "NO_TIER"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == ""
|
||||
cancel_mock.assert_awaited_once_with(TEST_USER_ID)
|
||||
# DB tier must be updated immediately — no webhook will fire for a missing sub
|
||||
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BASIC)
|
||||
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.NO_TIER)
|
||||
|
||||
|
||||
def test_get_subscription_status_includes_pending_tier(
|
||||
|
||||
```diff
@@ -44,6 +44,7 @@ from backend.api.model import (
     UploadFileResponse,
 )
 from backend.blocks import get_block, get_blocks
+from backend.copilot.rate_limit import get_tier_multipliers
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
 from backend.data.auth import api_key as api_key_db
@@ -56,12 +57,14 @@ from backend.data.credit import (
     UserCredit,
     cancel_stripe_subscription,
     create_subscription_checkout,
+    get_active_subscription_period_end,
     get_auto_top_up,
     get_pending_subscription_change,
     get_proration_credit_cents,
     get_subscription_price_id,
     get_user_credit_model,
+    handle_subscription_payment_failure,
+    handle_subscription_payment_success,
     modify_stripe_subscription_for_tier,
     release_pending_subscription_schedule,
     set_auto_top_up,
@@ -699,17 +702,42 @@ async def get_user_auto_top_up(


 class SubscriptionTierRequest(BaseModel):
-    tier: Literal["BASIC", "PRO", "MAX", "BUSINESS"]
+    tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]
     success_url: str = ""
     cancel_url: str = ""


 class SubscriptionStatusResponse(BaseModel):
-    tier: Literal["BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
+    tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
     monthly_cost: int  # amount in cents (Stripe convention)
     tier_costs: dict[str, int]  # tier name -> amount in cents
+    tier_multipliers: dict[str, float] = Field(
+        default_factory=dict,
+        description=(
+            "Tier → rate-limit multiplier. Covers the same tiers listed in"
+            " ``tier_costs`` so the frontend can render rate-limit badges"
+            " relative to the lowest visible tier without knowing backend"
+            " defaults."
+        ),
+    )
     proration_credit_cents: int  # unused portion of current sub to convert on upgrade
-    pending_tier: Optional[Literal["BASIC", "PRO", "MAX", "BUSINESS"]] = None
+    has_active_stripe_subscription: bool = Field(
+        default=False,
+        description=(
+            "True when the user has an active/trialing Stripe subscription. The"
+            " frontend uses this to branch upgrade UX: modify-in-place + saved-card"
+            " auto-charge when True, redirect to Stripe Checkout when False."
+        ),
+    )
+    current_period_end: Optional[int] = Field(
+        default=None,
+        description=(
+            "Unix timestamp of the active subscription's current_period_end. Used"
+            " to show the date Stripe will issue the next invoice (with prorated"
+            " upgrade charges, if any). None when no active sub."
+        ),
+    )
+    pending_tier: Optional[Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]] = None
     pending_tier_effective_at: Optional[datetime] = None
     url: str = Field(
         default="",
@@ -794,8 +822,11 @@ async def get_subscription_status(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> SubscriptionStatusResponse:
     user = await get_user_by_id(user_id)
-    tier = user.subscription_tier or SubscriptionTier.BASIC
+    tier = user.subscription_tier or SubscriptionTier.NO_TIER

+    # Tiers that *can* have a Stripe price configured (and therefore appear
+    # in the tier picker if the LD flag exposes a price-id). NO_TIER is not
+    # priceable — it's the implicit "no active subscription" state.
     priceable_tiers = [
         SubscriptionTier.BASIC,
         SubscriptionTier.PRO,
@@ -816,8 +847,23 @@ async def get_subscription_status(
         if pid:
             tier_costs[t.value] = cost

+    # Expose the effective rate-limit multipliers alongside prices so the
+    # frontend can render "Nx rate limits" relative to the lowest visible
+    # tier without hard-coding backend defaults. Only emit entries for tiers
+    # that land in ``tier_costs`` — rows hidden at the price layer must stay
+    # hidden in the multiplier layer too.
+    multipliers = await get_tier_multipliers()
+    tier_multipliers: dict[str, float] = {
+        t.value: multipliers.get(t, 1.0)
+        for t in priceable_tiers
+        if t.value in tier_costs
+    }
+
     current_monthly_cost = tier_costs.get(tier.value, 0)
-    proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
+    proration_credit, current_period_end = await asyncio.gather(
+        get_proration_credit_cents(user_id, current_monthly_cost),
+        get_active_subscription_period_end(user_id),
+    )

     try:
         pending = await get_pending_subscription_change(user_id)
@@ -837,11 +883,15 @@ async def get_subscription_status(
         tier=tier.value,
         monthly_cost=current_monthly_cost,
         tier_costs=tier_costs,
+        tier_multipliers=tier_multipliers,
         proration_credit_cents=proration_credit,
+        has_active_stripe_subscription=current_period_end is not None,
+        current_period_end=current_period_end,
     )
     if pending is not None:
         pending_tier_enum, pending_effective_at = pending
         if pending_tier_enum in (
+            SubscriptionTier.NO_TIER,
             SubscriptionTier.BASIC,
             SubscriptionTier.PRO,
             SubscriptionTier.MAX,
@@ -869,7 +919,7 @@ async def update_subscription_tier(
     # ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
     user = await get_user_by_id(user_id)
     if (
-        user.subscription_tier or SubscriptionTier.BASIC
+        user.subscription_tier or SubscriptionTier.NO_TIER
     ) == SubscriptionTier.ENTERPRISE:
         raise HTTPException(
             status_code=403,
@@ -881,7 +931,7 @@ async def update_subscription_tier(
     # collapsed behaviour that replaces the old /credits/subscription/cancel-pending
     # route. Safe when no pending change exists: release_pending_subscription_schedule
     # returns False and we simply return the current status.
-    if (user.subscription_tier or SubscriptionTier.BASIC) == tier:
+    if (user.subscription_tier or SubscriptionTier.NO_TIER) == tier:
         try:
             await release_pending_subscription_schedule(user_id)
         except stripe.StripeError as e:
@@ -903,18 +953,14 @@ async def update_subscription_tier(
         Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
     )

-    current_tier = user.subscription_tier or SubscriptionTier.BASIC
-    target_price_id, current_tier_price_id = await asyncio.gather(
-        get_subscription_price_id(tier),
-        get_subscription_price_id(current_tier),
-    )
+    target_price_id = await get_subscription_price_id(tier)

-    # Legacy cancel: target BASIC + stripe-price-id-basic unset. Schedule Stripe
-    # cancellation at period end; cancel_at_period_end=True lets the webhook flip
-    # the DB tier. No active sub (admin-granted) or payment disabled → DB flip.
-    # Once stripe-price-id-basic is configured, BASIC becomes a real sub and falls
-    # through to the modify/checkout flow below.
-    if tier == SubscriptionTier.BASIC and target_price_id is None:
+    # Cancel: target NO_TIER. Schedule Stripe cancellation at period end;
+    # cancel_at_period_end=True lets the webhook flip the DB tier. No active
+    # sub (admin-granted or never-paid) or payment disabled → DB flip.
+    # NO_TIER is never priceable, so this branch always fires for cancel
+    # requests regardless of LD config.
+    if tier == SubscriptionTier.NO_TIER:
         if payment_enabled:
             try:
                 had_subscription = await cancel_stripe_subscription(user_id)
@@ -950,32 +996,53 @@ async def update_subscription_tier(
             detail=f"Subscription not available for tier {tier.value}",
         )

-    # User has an active Stripe subscription (current tier has an LD price):
-    # modify it in-place. modify_stripe_subscription_for_tier returns False when no
-    # active sub exists — that's only a "DB-only flip is OK" signal for admin-granted
-    # paid tiers (PRO/BUSINESS with no Stripe record). Priced-BASIC users without a
-    # sub must still go through Checkout so they set up payment.
-    if current_tier_price_id is not None:
-        try:
-            modified = await modify_stripe_subscription_for_tier(user_id, tier)
-            if modified:
-                return await get_subscription_status(user_id)
-            if current_tier != SubscriptionTier.BASIC:
-                await set_subscription_tier(user_id, tier)
-                return await get_subscription_status(user_id)
-        except ValueError as e:
-            raise HTTPException(status_code=422, detail=str(e))
-        except stripe.StripeError as e:
-            logger.exception(
-                "Stripe error modifying subscription for user %s: %s", user_id, e
-            )
-            raise HTTPException(
-                status_code=502,
-                detail=(
-                    "Unable to update your subscription right now. "
-                    "Please try again or contact support."
-                ),
-            )
+    # Modify in place if there's a sub; else fall through to Checkout below.
+    try:
+        modified = await modify_stripe_subscription_for_tier(user_id, tier)
+        if modified:
+            return await get_subscription_status(user_id)
+    except ValueError as e:
+        raise HTTPException(status_code=422, detail=str(e))
+    except stripe.InvalidRequestError as e:
+        # Stripe rejects schedule modify when phases mix currencies, e.g. the
+        # active sub was checked out in GBP but the target tier's Price is
+        # USD-only. 502 reads as outage; surface a 422 with a specific message
+        # so the user/admin can see what to fix in Stripe.
+        msg = str(e)
+        if "currency" in msg.lower():
+            logger.warning(
+                "Currency mismatch on tier change for user %s: %s", user_id, msg
+            )
+            raise HTTPException(
+                status_code=422,
+                detail=(
+                    "Tier change unavailable for your current billing currency."
+                    " Please contact support — the target tier needs to be"
+                    " configured for your currency in Stripe before this"
+                    " change can go through."
+                ),
+            )
+        logger.exception(
+            "Stripe error modifying subscription for user %s: %s", user_id, e
+        )
+        raise HTTPException(
+            status_code=502,
+            detail=(
+                "Unable to update your subscription right now. "
+                "Please try again or contact support."
+            ),
+        )
+    except stripe.StripeError as e:
+        logger.exception(
+            "Stripe error modifying subscription for user %s: %s", user_id, e
+        )
+        raise HTTPException(
+            status_code=502,
+            detail=(
+                "Unable to update your subscription right now. "
+                "Please try again or contact support."
+            ),
+        )

     # No active Stripe subscription → create Stripe Checkout Session.
     if not request.success_url or not request.cancel_url:
@@ -1111,6 +1178,9 @@ async def stripe_webhook(request: Request):
     ):
         await sync_subscription_schedule_from_stripe(data_object)

+    if event_type == "invoice.payment_succeeded":
+        await handle_subscription_payment_success(data_object)
+
     if event_type == "invoice.payment_failed":
         await handle_subscription_payment_failure(data_object)
```
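The 422-vs-502 translation rule the endpoint (and its two new tests) pin down can be stated independently of FastAPI: only an `InvalidRequestError` whose message mentions a currency maps to a client-actionable 422; every other Stripe error stays a generic 502. A minimal sketch — `classify_stripe_error` is a hypothetical helper for illustration, not code from this PR:

```python
# Hypothetical helper illustrating the endpoint's error-translation rule.
def classify_stripe_error(message: str) -> int:
    """Map a Stripe InvalidRequestError message to an HTTP status.

    A currency mismatch is a Stripe configuration problem the caller/admin
    can act on, so it surfaces as 422; anything else is treated as an
    upstream failure and stays the generic 502.
    """
    if "currency" in message.lower():
        return 422
    return 502
```

This mirrors the `"currency" in msg.lower()` conditional in the diff above; the `non_currency_invalid_request_returns_502` test exists precisely so that widening this check later trips the suite.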
```diff
@@ -34,6 +34,7 @@ import backend.api.features.oauth
 import backend.api.features.otto.routes
 import backend.api.features.platform_linking.routes
 import backend.api.features.postmark.postmark
+import backend.api.features.push.routes as push_routes
 import backend.api.features.store.model
 import backend.api.features.store.routes
 import backend.api.features.v1
@@ -41,6 +42,7 @@ import backend.api.features.workspace.routes as workspace_routes
 import backend.data.block
 import backend.data.db
 import backend.data.graph
+import backend.data.redis_client
 import backend.data.user
 import backend.integrations.webhooks.utils
 import backend.util.service
@@ -95,6 +97,8 @@ async def lifespan_context(app: fastapi.FastAPI):
     verify_auth_settings()

     await backend.data.db.connect()
+    # Eager connect to fail-fast if Redis is unreachable.
+    await backend.data.redis_client.get_redis_async()

     # Configure thread pool for FastAPI sync operation performance
     # CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
@@ -146,7 +150,18 @@ async def lifespan_context(app: fastapi.FastAPI):
     except Exception as e:
         logger.warning(f"Error shutting down workspace storage: {e}")

-    await backend.data.db.disconnect()
+    # Each cleanup is wrapped so one failure doesn't block the rest. The
+    # Redis close in particular silences asyncio's "Unclosed ClusterNode"
+    # GC warning at interpreter shutdown.
+    try:
+        await backend.data.redis_client.disconnect_async()
+    except Exception:
+        logger.warning("redis_client.disconnect_async failed", exc_info=True)
+
+    try:
+        await backend.data.db.disconnect()
+    except Exception:
+        logger.warning("db.disconnect failed", exc_info=True)


 def custom_generate_unique_id(route: APIRoute):
@@ -379,6 +394,11 @@ app.include_router(
     tags=["oauth"],
     prefix="/api/oauth",
 )
+app.include_router(
+    push_routes.router,
+    tags=["push"],
+    prefix="/api/push",
+)
 app.include_router(
     backend.api.features.platform_linking.routes.router,
     tags=["platform-linking"],
```
```diff
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from contextlib import asynccontextmanager
 from typing import Protocol
@@ -17,14 +16,12 @@ from backend.api.model import (
     WSSubscribeGraphExecutionsRequest,
 )
 from backend.api.utils.cors import build_cors_params
-from backend.data.execution import AsyncRedisExecutionEventBus
-from backend.data.notification_bus import AsyncRedisNotificationEventBus
+from backend.data import db, redis_client
 from backend.data.user import DEFAULT_USER_ID
 from backend.monitoring.instrumentation import (
     instrument_fastapi,
     update_websocket_connections,
 )
-from backend.util.retry import continuous_retry
 from backend.util.service import AppProcess
 from backend.util.settings import AppEnvironment, Config, Settings

@@ -34,10 +31,24 @@ settings = Settings()

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    manager = get_connection_manager()
-    fut = asyncio.create_task(event_broadcaster(manager))
-    fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
-    yield
+    # Prisma is needed to resolve graph_id from graph_exec_id on subscribe.
+    await db.connect()
+    # Eager connect to fail-fast if Redis is unreachable.
+    await redis_client.get_redis_async()
+    try:
+        yield
+    finally:
+        # Each cleanup is wrapped so one failure doesn't block the rest. The
+        # Redis close silences asyncio's "Unclosed ClusterNode" GC warning at
+        # interpreter shutdown.
+        try:
+            await redis_client.disconnect_async()
+        except Exception:
+            logger.warning("redis_client.disconnect_async failed", exc_info=True)
+        try:
+            await db.disconnect()
+        except Exception:
+            logger.warning("db.disconnect failed", exc_info=True)


 docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
@@ -61,31 +72,6 @@ def get_connection_manager():
     return _connection_manager


-@continuous_retry()
-async def event_broadcaster(manager: ConnectionManager):
-    execution_bus = AsyncRedisExecutionEventBus()
-    notification_bus = AsyncRedisNotificationEventBus()
-
-    try:
-
-        async def execution_worker():
-            async for event in execution_bus.listen("*"):
-                await manager.send_execution_update(event)
-
-        async def notification_worker():
-            async for notification in notification_bus.listen("*"):
-                await manager.send_notification(
-                    user_id=notification.user_id,
-                    payload=notification.payload,
-                )
-
-        await asyncio.gather(execution_worker(), notification_worker())
-    finally:
-        # Ensure PubSub connections are closed on any exit to prevent leaks
-        await execution_bus.close()
-        await notification_bus.close()
-
-
 async def authenticate_websocket(websocket: WebSocket) -> str:
     if not settings.config.enable_auth:
         return DEFAULT_USER_ID
@@ -297,6 +283,21 @@ async def websocket_router(
                     ).model_dump_json()
                 )
                 continue
+            except ValueError as e:
+                logger.warning(
+                    "Subscription rejected for user #%s on '%s': %s",
+                    user_id,
+                    message.method.value,
+                    e,
+                )
+                await websocket.send_text(
+                    WSMessage(
+                        method=WSMethod.ERROR,
+                        success=False,
+                        error=str(e),
+                    ).model_dump_json()
+                )
+                continue
             except Exception as e:
                 logger.error(
                     f"Error while handling '{message.method.value}' message "
@@ -321,9 +322,13 @@ async def websocket_router(
                 )

     except WebSocketDisconnect:
-        manager.disconnect_socket(websocket, user_id=user_id)
         logger.debug("WebSocket client disconnected")
+    except Exception:
+        logger.exception(f"Unexpected error in websocket_router for user #{user_id}")
+    finally:
+        # Always release subscription pumps + Redis connections, regardless of how
+        # the loop exited — otherwise non-WebSocketDisconnect failures leak both.
+        await manager.disconnect_socket(websocket, user_id=user_id)
+        update_websocket_connections(user_id, -1)
```
```diff
@@ -44,9 +44,12 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
         "backend.api.ws_api.build_cors_params", return_value=cors_params
     )

-    with override_config(
-        settings, "backend_cors_allow_origins", cors_params["allow_origins"]
-    ), override_config(settings, "app_env", AppEnvironment.LOCAL):
+    with (
+        override_config(
+            settings, "backend_cors_allow_origins", cors_params["allow_origins"]
+        ),
+        override_config(settings, "app_env", AppEnvironment.LOCAL),
+    ):
         WebsocketServer().run()

     build_cors.assert_called_once_with(
@@ -65,9 +68,12 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
 def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
     mocker.patch("backend.api.ws_api.uvicorn.run")

-    with override_config(
-        settings, "backend_cors_allow_origins", ["http://localhost:3000"]
-    ), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
+    with (
+        override_config(
+            settings, "backend_cors_allow_origins", ["http://localhost:3000"]
+        ),
+        override_config(settings, "app_env", AppEnvironment.PRODUCTION),
+    ):
         with pytest.raises(ValueError):
             WebsocketServer().run()

@@ -290,7 +296,232 @@ async def test_handle_unsubscribe_missing_data(
         message=message,
     )

-    mock_manager._unsubscribe.assert_not_called()
+    mock_manager.unsubscribe_graph_exec.assert_not_called()
     mock_websocket.send_text.assert_called_once()
     assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
     assert '"success":false' in mock_websocket.send_text.call_args[0][0]
+
+
+# ---------- Per-graph subscribe branch ----------
+
+
+@pytest.mark.asyncio
+async def test_handle_subscribe_graph_execs_branch(
+    mock_websocket: AsyncMock, mock_manager: AsyncMock
+) -> None:
+    """The SUBSCRIBE_GRAPH_EXECS branch must route to subscribe_graph_execs,
+    not subscribe_graph_exec — regression guard for the aggregate channel."""
+    message = WSMessage(
+        method=WSMethod.SUBSCRIBE_GRAPH_EXECS,
+        data={"graph_id": "graph-abc"},
+    )
+    mock_manager.subscribe_graph_execs.return_value = (
+        "user-1|graph#graph-abc|executions"
+    )
+
+    await handle_subscribe(
+        connection_manager=cast(ConnectionManager, mock_manager),
+        websocket=cast(WebSocket, mock_websocket),
+        user_id="user-1",
+        message=message,
+    )
+
+    mock_manager.subscribe_graph_execs.assert_called_once_with(
+        user_id="user-1",
+        graph_id="graph-abc",
+        websocket=mock_websocket,
+    )
+    mock_manager.subscribe_graph_exec.assert_not_called()
+    mock_websocket.send_text.assert_called_once()
+    assert (
+        '"method":"subscribe_graph_executions"'
+        in mock_websocket.send_text.call_args[0][0]
+    )
+    assert '"success":true' in mock_websocket.send_text.call_args[0][0]
+
+
+@pytest.mark.asyncio
+async def test_handle_subscribe_rejects_unrelated_method(
+    mock_websocket: AsyncMock, mock_manager: AsyncMock
+) -> None:
+    """handle_subscribe must raise for methods that aren't SUBSCRIBE_*."""
+    import pytest as _pytest
+
+    message = WSMessage(
+        method=WSMethod.HEARTBEAT,
+        data={"graph_exec_id": "x"},
+    )
+
+    with _pytest.raises(ValueError):
+        await handle_subscribe(
+            connection_manager=cast(ConnectionManager, mock_manager),
+            websocket=cast(WebSocket, mock_websocket),
+            user_id="user-1",
+            message=message,
+        )
+
+
+# ---------- authenticate_websocket branches ----------
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_missing_token_closes_4001(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
```
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", True)
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
ws.close.assert_awaited_once()
|
||||
assert ws.close.call_args.kwargs["code"] == 4001
|
||||
assert user_id == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_websocket_invalid_token_closes_4003(mocker) -> None:
|
||||
from backend.api.ws_api import authenticate_websocket
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", True)
|
||||
mocker.patch(
|
||||
"backend.api.ws_api.parse_jwt_token", side_effect=ValueError("bad token")
|
||||
)
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {"token": "abc"}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
ws.close.assert_awaited_once()
|
||||
assert ws.close.call_args.kwargs["code"] == 4003
|
||||
assert user_id == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_websocket_missing_sub_closes_4002(mocker) -> None:
|
||||
from backend.api.ws_api import authenticate_websocket
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", True)
|
||||
mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"not_sub": "x"})
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {"token": "abc"}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
ws.close.assert_awaited_once()
|
||||
assert ws.close.call_args.kwargs["code"] == 4002
|
||||
assert user_id == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_websocket_happy_path_returns_sub(mocker) -> None:
|
||||
from backend.api.ws_api import authenticate_websocket
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", True)
|
||||
mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"sub": "user-X"})
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {"token": "abc"}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
assert user_id == "user-X"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_websocket_auth_disabled_returns_default(mocker) -> None:
|
||||
from backend.api.ws_api import authenticate_websocket
|
||||
|
||||
mocker.patch.object(settings.config, "enable_auth", False)
|
||||
ws = AsyncMock(spec=WebSocket)
|
||||
ws.query_params = {}
|
||||
|
||||
user_id = await authenticate_websocket(ws)
|
||||
|
||||
assert user_id == DEFAULT_USER_ID
|
||||
|
||||
|
||||
# ---------- get_connection_manager singleton ----------
|
||||
|
||||
|
||||
def test_get_connection_manager_singleton() -> None:
|
||||
"""Repeated calls must return the same ConnectionManager — the WS router
|
||||
depends on a single process-wide subscription table."""
|
||||
import backend.api.ws_api as ws_api
|
||||
|
||||
ws_api._connection_manager = None
|
||||
a = ws_api.get_connection_manager()
|
||||
b = ws_api.get_connection_manager()
|
||||
assert a is b
|
||||
assert isinstance(a, ConnectionManager)
|
||||
|
||||
|
||||
# ---------- Lifespan: Prisma connect/disconnect ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_connects_and_disconnects_prisma(mocker) -> None:
|
||||
"""Lifespan must both connect() and disconnect() db — the subscribe path
|
||||
resolves graph_id via Prisma so a missing connect() is the regression bug."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.ws_api import lifespan
|
||||
|
||||
mock_db = mocker.patch("backend.api.ws_api.db")
|
||||
mock_db.connect = AsyncMock()
|
||||
mock_db.disconnect = AsyncMock()
|
||||
|
||||
dummy_app = FastAPI()
|
||||
async with lifespan(dummy_app):
|
||||
mock_db.connect.assert_awaited_once()
|
||||
mock_db.disconnect.assert_not_called()
|
||||
mock_db.disconnect.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_still_disconnects_on_exception(mocker) -> None:
|
||||
"""If the app raises inside the yield, Prisma must still disconnect."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.ws_api import lifespan
|
||||
|
||||
mock_db = mocker.patch("backend.api.ws_api.db")
|
||||
mock_db.connect = AsyncMock()
|
||||
mock_db.disconnect = AsyncMock()
|
||||
|
||||
dummy_app = FastAPI()
|
||||
|
||||
class _Boom(Exception):
|
||||
pass
|
||||
|
||||
with pytest.raises(_Boom):
|
||||
async with lifespan(dummy_app):
|
||||
raise _Boom()
|
||||
|
||||
mock_db.disconnect.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------- Health endpoint ----------
|
||||
|
||||
|
||||
def test_health_endpoint_returns_ok() -> None:
|
||||
# TestClient triggers lifespan — stub it out so Prisma isn't hit.
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import backend.api.ws_api as ws_api
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_lifespan(app):
|
||||
yield
|
||||
|
||||
# Replace the app-level lifespan temporarily.
|
||||
real_router_lifespan = ws_api.app.router.lifespan_context
|
||||
ws_api.app.router.lifespan_context = _noop_lifespan
|
||||
try:
|
||||
with TestClient(ws_api.app) as client:
|
||||
r = client.get("/")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"status": "healthy"}
|
||||
finally:
|
||||
ws_api.app.router.lifespan_context = real_router_lifespan
|
||||
|
||||
@@ -38,6 +38,7 @@ def main(**kwargs):
|
||||
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.api.ws_api import WebsocketServer
|
||||
from backend.copilot.bot.app import CoPilotChatBridge
|
||||
from backend.copilot.executor.manager import CoPilotExecutor
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
@@ -52,6 +53,7 @@ def main(**kwargs):
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
CoPilotChatBridge(),
|
||||
CoPilotExecutor(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Provider descriptions for services that don't yet have their own ``_config.py``.
|
||||
|
||||
Every provider in ``_STATIC_PROVIDER_CONFIGS`` below is declared here because
|
||||
its block code currently lives either in a single shared file (e.g. the 8 LLM
|
||||
providers in ``blocks/llm.py``) or in a single-file block that has no dedicated
|
||||
directory (e.g. ``blocks/reddit.py``).
|
||||
|
||||
This file gets loaded by the block auto-loader in ``blocks/__init__.py``
|
||||
(``rglob("*.py")`` picks it up) so the ``ProviderBuilder(...).build()`` calls
|
||||
run at startup and populate ``AutoRegistry`` before the first API request.
|
||||
|
||||
**Migration path:** when a provider graduates into its own directory with a
|
||||
proper ``_config.py`` (following the SDK pattern, e.g. ``blocks/linear/_config.py``),
|
||||
delete its entry here. The metadata will still be served by
|
||||
``GET /integrations/providers`` — it just moves to live next to the provider's
|
||||
auth and webhook config.
|
||||
"""
|
||||
|
||||
from backend.data.model import CredentialsType
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
_STATIC_PROVIDER_CONFIGS: dict[str, tuple[str, tuple[CredentialsType, ...]]] = {
|
||||
# LLM providers that share blocks/llm.py
|
||||
"aiml_api": ("Unified access to 100+ AI models", ("api_key",)),
|
||||
"anthropic": ("Claude language models", ("api_key",)),
|
||||
"groq": ("Fast LLM inference", ("api_key",)),
|
||||
"llama_api": ("Llama model hosting", ("api_key",)),
|
||||
"ollama": ("Run open-source LLMs locally", ("api_key",)),
|
||||
"open_router": ("One API for every LLM", ("api_key",)),
|
||||
"openai": ("GPT models and embeddings", ("api_key",)),
|
||||
"v0": ("AI-generated UI components", ("api_key",)),
|
||||
# Single-file providers (one provider per standalone blocks/*.py file)
|
||||
"d_id": ("AI avatar and video generation", ("api_key",)),
|
||||
"e2b": ("Sandboxed code execution", ("api_key",)),
|
||||
"google_maps": ("Places, directions, geocoding", ("api_key",)),
|
||||
"http": ("Generic HTTP requests", ("api_key", "host_scoped")),
|
||||
"ideogram": ("Text-to-image generation", ("api_key",)),
|
||||
"medium": ("Publish stories and posts", ("api_key",)),
|
||||
"mem0": ("Long-term memory for agents", ("api_key",)),
|
||||
"openweathermap": ("Weather data and forecasts", ("api_key",)),
|
||||
"pinecone": ("Managed vector database", ("api_key",)),
|
||||
"reddit": ("Subreddits, posts, and comments", ("oauth2",)),
|
||||
"revid": ("AI-generated short-form video", ("api_key",)),
|
||||
"screenshotone": ("Automated website screenshots", ("api_key",)),
|
||||
"smtp": ("Send email via SMTP", ("user_password",)),
|
||||
"unreal_speech": ("Low-cost text-to-speech", ("api_key",)),
|
||||
"webshare_proxy": ("Rotating proxies for scraping", ("api_key",)),
|
||||
}
|
||||
|
||||
for _name, (_description, _auth_types) in _STATIC_PROVIDER_CONFIGS.items():
|
||||
(
|
||||
ProviderBuilder(_name)
|
||||
.with_description(_description)
|
||||
.with_supported_auth_types(*_auth_types)
|
||||
.build()
|
||||
)
|
||||
@@ -12,6 +12,7 @@ from backend.sdk import APIKeyCredentials, BlockCostType, ProviderBuilder, Secre
|
||||
# past billing. Revisit once AgentMail publishes usage-based pricing.
|
||||
agent_mail = (
|
||||
ProviderBuilder("agent_mail")
|
||||
.with_description("Managed email accounts for agents")
|
||||
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
|
||||
@@ -10,6 +10,7 @@ from ._webhook import AirtableWebhookManager
|
||||
# Configure the Airtable provider with API key authentication
|
||||
airtable = (
|
||||
ProviderBuilder("airtable")
|
||||
.with_description("Bases, tables, and records")
|
||||
.with_api_key("AIRTABLE_API_KEY", "Airtable Personal Access Token")
|
||||
.with_webhook_manager(AirtableWebhookManager)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
|
||||
15
autogpt_platform/backend/backend/blocks/apollo/_config.py
Normal file
@@ -0,0 +1,15 @@
"""Provider registration for Apollo.

Registers the provider description shown in the settings integrations UI.
Apollo doesn't use a full :class:`ProviderBuilder` chain (auth is set up in
``_auth.py``), so this file only declares metadata.
"""

from backend.sdk import ProviderBuilder

apollo = (
    ProviderBuilder("apollo")
    .with_description("Sales intelligence and prospecting")
    .with_supported_auth_types("api_key")
    .build()
)
@@ -7,6 +7,7 @@ import logging
import uuid
from typing import TYPE_CHECKING, Any

from pydantic import field_validator
from typing_extensions import TypedDict  # Needed for Python <3.12 compatibility

from backend.blocks._base import (
@@ -17,6 +18,7 @@ from backend.blocks._base import (
    BlockSchemaOutput,
)
from backend.copilot.permissions import (
    DISABLED_LEGACY_TOOL_NAMES,
    CopilotPermissions,
    ToolName,
    all_known_tool_names,
@@ -198,6 +200,13 @@ class AutoPilotBlock(Block):
        # timeouts internally; wrapping with asyncio.timeout corrupts the
        # SDK's internal stream (see service.py CRITICAL comment).

        @field_validator("tools", mode="before")
        @classmethod
        def strip_disabled_legacy_tools(cls, tools: Any) -> Any:
            if not isinstance(tools, list):
                return tools
            return [tool for tool in tools if tool not in DISABLED_LEGACY_TOOL_NAMES]

    class Output(BlockSchemaOutput):
        """Output schema for the AutoPilot block."""


@@ -62,6 +62,14 @@ class TestBuildAndValidatePermissions:
        with pytest.raises(ValidationError, match="not_a_real_tool"):
            _make_input(tools=["not_a_real_tool"])

    async def test_disabled_legacy_tool_is_accepted_and_removed(self):
        inp = _make_input(tools=["ask_question", "run_block"])
        result = await _build_and_validate_permissions(inp)

        assert inp.tools == ["run_block"]
        assert isinstance(result, CopilotPermissions)
        assert result.tools == ["run_block"]

    async def test_valid_block_name_accepted(self):
        mock_block_cls = MagicMock()
        mock_block_cls.return_value.name = "HTTP Request"

@@ -18,4 +18,9 @@ reach a block as a "profile key".

from backend.sdk import ProviderBuilder

ayrshare = ProviderBuilder("ayrshare").with_managed_api_key().build()
ayrshare = (
    ProviderBuilder("ayrshare")
    .with_description("Post to every social network")
    .with_managed_api_key()
    .build()
)

@@ -7,6 +7,7 @@ from backend.sdk import BlockCostType, ProviderBuilder
# Configure the Meeting BaaS provider with API key authentication
baas = (
    ProviderBuilder("baas")
    .with_description("Meeting recording and transcription")
    .with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
    .with_base_cost(5, BlockCostType.RUN)  # Higher cost for meeting recording service
    .build()

@@ -4,6 +4,7 @@ Meeting BaaS bot (recording) blocks.

from typing import Optional

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -21,13 +22,15 @@ from backend.sdk import (
from ._api import MeetingBaasAPI
from ._config import baas

# Meeting BaaS recording rate: $0.69 per hour.
_MEETING_BAAS_USD_PER_SECOND = 0.69 / 3600

# Join bills a flat 30 cr commit (covers median short meeting);
# FetchMeetingData bills the duration-scaled remainder from the
# `duration_seconds` field on the API response. Long meetings no
# longer under-bill.


# Meeting BaaS charges $0.69/hour of recording. The Join block is the
# trigger that starts the recording session; the meeting itself runs out
# of band (we don't get duration back from the FetchMeetingData response
# we use). 30 cr ≈ $0.30 covers a median 30-minute meeting with margin.
# Interim until FetchMeetingData surfaces duration for post-flight
# reconciliation.
@cost(BlockCost(cost_type=BlockCostType.RUN, cost_amount=30))
class BaasBotJoinMeetingBlock(Block):
    """
@@ -144,6 +147,7 @@ class BaasBotLeaveMeetingBlock(Block):
        yield "left", left


@cost(BlockCost(cost_type=BlockCostType.COST_USD, cost_amount=150))
class BaasBotFetchMeetingDataBlock(Block):
    """
    Pull MP4 URL, transcript & metadata for a completed meeting.
@@ -186,9 +190,21 @@ class BaasBotFetchMeetingDataBlock(Block):
            include_transcripts=input_data.include_transcripts,
        )

        bot_meta = data.get("bot_data", {}).get("bot", {}) or {}
        # Bill recording duration via COST_USD so multi-hour meetings
        # scale past the Join block's flat 30 cr deposit.
        duration_seconds = float(bot_meta.get("duration_seconds") or 0)
        if duration_seconds > 0:
            self.merge_stats(
                NodeExecutionStats(
                    provider_cost=duration_seconds * _MEETING_BAAS_USD_PER_SECOND,
                    provider_cost_type="cost_usd",
                )
            )

        yield "mp4_url", data.get("mp4", "")
        yield "transcript", data.get("bot_data", {}).get("transcripts", [])
        yield "metadata", data.get("bot_data", {}).get("bot", {})
        yield "metadata", bot_meta


class BaasBotDeleteRecordingBlock(Block):

@@ -0,0 +1,86 @@
"""Unit tests for Meeting BaaS duration-based cost emission."""

from unittest.mock import AsyncMock, patch

import pytest
from pydantic import SecretStr

from backend.blocks.baas.bots import (
    _MEETING_BAAS_USD_PER_SECOND,
    BaasBotFetchMeetingDataBlock,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="baas",
    title="Mock BaaS API Key",
    api_key=SecretStr("mock-baas-api-key"),
    expires_at=None,
)


def test_usd_per_second_derives_from_published_rate():
    """$0.69/hour published rate → ~$0.000192/second."""
    assert _MEETING_BAAS_USD_PER_SECOND == pytest.approx(0.69 / 3600)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "duration_seconds, expected_usd",
    [
        (3600, 0.69),  # 1 hour
        (1800, 0.345),  # 30 min
        (0, None),  # no recording → no emission
        (None, None),  # missing duration field → no emission
    ],
)
async def test_fetch_meeting_data_emits_duration_cost_usd(
    duration_seconds, expected_usd
):
    """FetchMeetingData extracts duration_seconds from bot metadata and
    emits provider_cost / cost_usd scaled by the published $0.69/hr rate.
    Emission is skipped when duration is 0 or missing.
    """
    block = BaasBotFetchMeetingDataBlock()

    bot_meta = {"id": "bot-xyz"}
    if duration_seconds is not None:
        bot_meta["duration_seconds"] = duration_seconds

    mock_api = AsyncMock()
    mock_api.get_meeting_data.return_value = {
        "mp4": "https://example/recording.mp4",
        "bot_data": {"bot": bot_meta, "transcripts": []},
    }

    captured: list[NodeExecutionStats] = []
    with (
        patch("backend.blocks.baas.bots.MeetingBaasAPI", return_value=mock_api),
        patch.object(block, "merge_stats", side_effect=captured.append),
    ):
        outputs = []
        async for name, val in block.run(
            block.input_schema(
                credentials={
                    "id": TEST_CREDENTIALS.id,
                    "provider": TEST_CREDENTIALS.provider,
                    "type": TEST_CREDENTIALS.type,
                },
                bot_id="bot-xyz",
                include_transcripts=False,
            ),
            credentials=TEST_CREDENTIALS,
        ):
            outputs.append((name, val))

    # Always yields the 3 outputs regardless of duration.
    names = [n for n, _ in outputs]
    assert "mp4_url" in names and "metadata" in names

    if expected_usd is None:
        assert captured == []
    else:
        assert len(captured) == 1
        assert captured[0].provider_cost == pytest.approx(expected_usd)
        assert captured[0].provider_cost_type == "cost_usd"
@@ -2,6 +2,7 @@ from backend.sdk import BlockCostType, ProviderBuilder

bannerbear = (
    ProviderBuilder("bannerbear")
    .with_description("Auto-generate images and videos")
    .with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
    .with_base_cost(3, BlockCostType.RUN)
    .build()

@@ -433,7 +433,7 @@ class TestJinaEmbeddingBlockCostTracking:
class TestUnrealTextToSpeechBlockCostTracking:
    @pytest.mark.asyncio
    async def test_merge_stats_called_with_character_count(self):
        """provider_cost equals len(text) with type='characters'."""
        """provider_cost = len(text) * $0.000016 with type='cost_usd'."""
        from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
        from backend.blocks.text_to_speech_block import (
            TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
@@ -461,12 +461,12 @@ class TestUnrealTextToSpeechBlockCostTracking:

        mock_merge.assert_called_once()
        stats = mock_merge.call_args[0][0]
        assert stats.provider_cost == float(len(test_text))
        assert stats.provider_cost_type == "characters"
        assert stats.provider_cost == pytest.approx(len(test_text) * 0.000016)
        assert stats.provider_cost_type == "cost_usd"

    @pytest.mark.asyncio
    async def test_empty_text_gives_zero_characters(self):
        """An empty text string results in provider_cost=0.0."""
        """An empty text string results in provider_cost=0.0 (cost_usd)."""
        from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
        from backend.blocks.text_to_speech_block import (
            TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
@@ -494,7 +494,7 @@ class TestUnrealTextToSpeechBlockCostTracking:
        mock_merge.assert_called_once()
        stats = mock_merge.call_args[0][0]
        assert stats.provider_cost == 0.0
        assert stats.provider_cost_type == "characters"
        assert stats.provider_cost_type == "cost_usd"


# ---------------------------------------------------------------------------

@@ -17,6 +17,7 @@ from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -431,6 +432,7 @@ class ClaudeCodeBlock(Block):
        # The JSON output contains the result
        output_data = json.loads(raw_output)
        response = output_data.get("result", raw_output)
        self._record_cli_cost(output_data)

        # Build conversation history entry
        turn_entry = f"User: {prompt}\nClaude: {response}"
@@ -484,6 +486,23 @@ class ClaudeCodeBlock(Block):
        escaped = prompt.replace("'", "'\"'\"'")
        return f"'{escaped}'"

    def _record_cli_cost(self, output_data: dict) -> None:
        """Feed Claude Code CLI's `total_cost_usd` to the COST_USD resolver.

        The CLI rolls up Anthropic LLM + internal tool-call spend into
        ``total_cost_usd`` on its JSON response; piping it through
        ``merge_stats`` lets the wallet reflect real spend.
        """
        total_cost_usd = output_data.get("total_cost_usd")
        if total_cost_usd is None:
            return
        self.merge_stats(
            NodeExecutionStats(
                provider_cost=float(total_cost_usd),
                provider_cost_type="cost_usd",
            )
        )

    async def run(
        self,
        input_data: Input,

106
autogpt_platform/backend/backend/blocks/claude_code_cost_test.py
Normal file
@@ -0,0 +1,106 @@
"""Unit tests for ClaudeCodeBlock COST_USD billing migration.

Verifies:
- Block emits provider_cost / cost_usd when Claude Code CLI returns
  total_cost_usd.
- block_usage_cost resolves the COST_USD entry to the expected ceil(usd *
  cost_amount) credit charge.
- Missing total_cost_usd gracefully produces provider_cost=None (no bill).
"""

from unittest.mock import MagicMock, patch

import pytest

from backend.blocks._base import BlockCostType
from backend.blocks.claude_code import ClaudeCodeBlock
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.model import NodeExecutionStats
from backend.executor.utils import block_usage_cost


def test_claude_code_registered_as_cost_usd_150():
    """Sanity: BLOCK_COSTS holds the COST_USD, 150 cr/$ entry."""
    entries = BLOCK_COSTS[ClaudeCodeBlock]
    assert len(entries) == 1
    entry = entries[0]
    assert entry.cost_type == BlockCostType.COST_USD
    assert entry.cost_amount == 150


@pytest.mark.parametrize(
    "total_cost_usd, expected_credits",
    [
        (0.50, 75),  # $0.50 × 150 = 75 cr
        (1.00, 150),  # $1.00 × 150 = 150 cr
        (0.0134, 3),  # ceil(0.0134 × 150) = ceil(2.01) = 3
        (2.00, 300),  # $2 × 150 = 300 cr
        (0.001, 1),  # ceil(0.001 × 150) = ceil(0.15) = 1 — no 0-cr leak on
        # sub-cent runs
    ],
)
def test_cost_usd_resolver_applies_150_multiplier(total_cost_usd, expected_credits):
    """block_usage_cost with cost_usd stats returns ceil(usd * 150)."""
    block = ClaudeCodeBlock()
    # cost_filter requires matching e2b_credentials; supply the ones the
    # registration uses so _is_cost_filter_match accepts the input.
    entry = BLOCK_COSTS[ClaudeCodeBlock][0]
    input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
    stats = NodeExecutionStats(
        provider_cost=total_cost_usd,
        provider_cost_type="cost_usd",
    )
    cost, matching_filter = block_usage_cost(
        block=block, input_data=input_data, stats=stats
    )
    assert cost == expected_credits
    assert matching_filter == entry.cost_filter


def test_cost_usd_resolver_returns_zero_when_stats_missing_cost():
    """Pre-flight (no stats) or unbilled run (provider_cost None) → 0."""
    block = ClaudeCodeBlock()
    entry = BLOCK_COSTS[ClaudeCodeBlock][0]
    input_data = {"e2b_credentials": entry.cost_filter["e2b_credentials"]}
    # No stats at all → pre-flight path, returns 0.
    pre_cost, _ = block_usage_cost(block=block, input_data=input_data)
    assert pre_cost == 0
    # Stats present but no provider_cost → resolver can't bill.
    stats = NodeExecutionStats()
    post_cost, _ = block_usage_cost(block=block, input_data=input_data, stats=stats)
    assert post_cost == 0


def test_record_cli_cost_emits_provider_cost_when_total_cost_present():
    """``_record_cli_cost`` (the helper called from ``execute_claude_code``)
    must emit a single ``merge_stats`` with provider_cost + cost_usd tag
    when the CLI JSON payload carries ``total_cost_usd``.
    """
    block = ClaudeCodeBlock()
    captured: list[NodeExecutionStats] = []
    with patch.object(block, "merge_stats", side_effect=captured.append):
        block._record_cli_cost(
            {
                "result": "hello from claude",
                "total_cost_usd": 0.0421,
                "usage": {"input_tokens": 1234, "output_tokens": 56},
            }
        )

    assert len(captured) == 1
    stats = captured[0]
    assert stats.provider_cost == pytest.approx(0.0421)
    assert stats.provider_cost_type == "cost_usd"


def test_record_cli_cost_skips_merge_when_total_cost_absent():
    """If the CLI payload lacks ``total_cost_usd`` (legacy / non-JSON
    output), ``_record_cli_cost`` must not call ``merge_stats`` — otherwise
    we'd pollute telemetry with a ``cost_usd`` emission that has no real
    cost attached.
    """
    block = ClaudeCodeBlock()
    mock = MagicMock()
    with patch.object(block, "merge_stats", mock):
        block._record_cli_cost({"result": "hello"})
    mock.assert_not_called()
@@ -151,6 +151,17 @@ class CodeGenerationBlock(Block):
        )
        self.execution_stats = NodeExecutionStats()

    # GPT-5.1-Codex published pricing: $1.25 / 1M input, $10 / 1M output.
    _INPUT_USD_PER_1M = 1.25
    _OUTPUT_USD_PER_1M = 10.0

    @staticmethod
    def _compute_token_usd(input_tokens: int, output_tokens: int) -> float:
        return (
            input_tokens * CodeGenerationBlock._INPUT_USD_PER_1M
            + output_tokens * CodeGenerationBlock._OUTPUT_USD_PER_1M
        ) / 1_000_000

    async def call_codex(
        self,
        *,
@@ -189,13 +200,15 @@ class CodeGenerationBlock(Block):
        response_id = response.id or ""

        # Update usage stats
        self.execution_stats.input_token_count = (
            response.usage.input_tokens if response.usage else 0
        )
        self.execution_stats.output_token_count = (
            response.usage.output_tokens if response.usage else 0
        )
        input_tokens = response.usage.input_tokens if response.usage else 0
        output_tokens = response.usage.output_tokens if response.usage else 0
        self.execution_stats.input_token_count = input_tokens
        self.execution_stats.output_token_count = output_tokens
        self.execution_stats.llm_call_count += 1
        self.execution_stats.provider_cost = self._compute_token_usd(
            input_tokens, output_tokens
        )
        self.execution_stats.provider_cost_type = "cost_usd"

        return CodexCallResult(
            response=text_output,

10
autogpt_platform/backend/backend/blocks/compass/_config.py
Normal file
@@ -0,0 +1,10 @@
"""Provider registration for Compass — metadata only (auth lives elsewhere)."""

from backend.sdk import ProviderBuilder

compass = (
    ProviderBuilder("compass")
    .with_description("Geospatial context for agents")
    .with_supported_auth_types("api_key")
    .build()
)
226
autogpt_platform/backend/backend/blocks/cost_leak_fixes_test.py
Normal file
@@ -0,0 +1,226 @@
"""Coverage tests for the cost-leak fixes in this PR.

Each block's ``run()`` / helper emits provider_cost + cost_usd (or items)
via merge_stats so the post-flight resolver bills real provider spend.
Tests here drive that emission path directly so a regression on any one
block surfaces immediately.
"""

from unittest.mock import patch

import pytest
from pydantic import SecretStr

from backend.blocks._base import BlockCostType
from backend.blocks.ai_condition import AIConditionBlock
from backend.data.block_cost_config import BLOCK_COSTS, LLM_COST
from backend.data.model import APIKeyCredentials, NodeExecutionStats

# -------- AIConditionBlock registration --------


def test_ai_condition_registered_under_llm_cost():
    """AIConditionBlock was running wallet-free before this PR; verify it
    now resolves through the same per-model LLM_COST table as every other
    LLM block.
    """
    assert BLOCK_COSTS[AIConditionBlock] is LLM_COST


# -------- Pinecone insert ITEMS emission --------


@pytest.mark.asyncio
async def test_pinecone_insert_emits_items_provider_cost():
    from backend.blocks.pinecone import PineconeInsertBlock

    block = PineconeInsertBlock()
    captured: list[NodeExecutionStats] = []

    class _FakeIndex:
        def upsert(self, **_):
            return None

    class _FakePinecone:
        def __init__(self, *_, **__):
            pass

        def Index(self, _name):
            return _FakeIndex()

    with (
        patch("backend.blocks.pinecone.Pinecone", _FakePinecone),
        patch.object(block, "merge_stats", side_effect=captured.append),
    ):
        input_data = block.input_schema(
            credentials={
                "id": "00000000-0000-0000-0000-000000000000",
                "provider": "pinecone",
                "type": "api_key",
            },
            index="my-index",
            chunks=["alpha", "beta", "gamma"],
            embeddings=[[0.1] * 4, [0.2] * 4, [0.3] * 4],
            namespace="",
            metadata={},
        )

        creds = APIKeyCredentials(
            id="00000000-0000-0000-0000-000000000000",
            provider="pinecone",
            title="mock",
            api_key=SecretStr("mock-key"),
            expires_at=None,
        )
        outputs = [(n, v) async for n, v in block.run(input_data, credentials=creds)]

    assert any(name == "upsert_response" for name, _ in outputs)
    assert len(captured) == 1
    stats = captured[0]
    assert stats.provider_cost == pytest.approx(3.0)
    assert stats.provider_cost_type == "items"


# -------- Narration model-aware per-char rate --------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_id, expected_rate_per_char",
|
||||
[
|
||||
("eleven_flash_v2_5", 0.000167 * 0.5),
|
||||
("eleven_turbo_v2_5", 0.000167 * 0.5),
|
||||
("eleven_multilingual_v2", 0.000167 * 1.0),
|
||||
("eleven_turbo_v2", 0.000167 * 1.0),
|
||||
],
|
||||
)
|
||||
def test_narration_per_char_rate_scales_with_model(model_id, expected_rate_per_char):
|
||||
"""Drive VideoNarrationBlock._record_script_cost directly so a regression
|
||||
that drops the model-aware branching (e.g. hardcoding 1.0 cr/char for
|
||||
all models) makes this test fail.
|
||||
"""
|
||||
from backend.blocks.video.narration import VideoNarrationBlock
|
||||
|
||||
block = VideoNarrationBlock()
|
||||
captured: list[NodeExecutionStats] = []
|
||||
with patch.object(block, "merge_stats", side_effect=captured.append):
|
||||
block._record_script_cost("x" * 5000, model_id)
|
||||
|
||||
assert len(captured) == 1
|
||||
stats = captured[0]
|
||||
assert stats.provider_cost == pytest.approx(5000 * expected_rate_per_char)
|
||||
assert stats.provider_cost_type == "cost_usd"
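A minimal standalone sketch of the model-aware per-character rate that the parametrized test above pins down. The rate values come from the test's own parametrization; the function and constant names here are illustrative, not the repo's.

```python
# Sketch only: rates taken from the test parametrization above, names hypothetical.
BASE_RATE_PER_CHAR = 0.000167  # USD per character at the 1.0x tier

# Flash/turbo v2.5 voices bill at half the base per-char rate.
HALF_RATE_MODELS = {"eleven_flash_v2_5", "eleven_turbo_v2_5"}


def narration_cost_usd(script: str, model_id: str) -> float:
    """Per-character narration cost, scaled by the model's rate tier."""
    multiplier = 0.5 if model_id in HALF_RATE_MODELS else 1.0
    return len(script) * BASE_RATE_PER_CHAR * multiplier
```

So the 5000-character script in the test costs 5000 × 0.000167 × 0.5 ≈ 0.4175 USD on a half-rate model and twice that on a full-rate one.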


# -------- Perplexity None-guard on x-total-cost --------


@pytest.mark.parametrize(
    "openrouter_cost, expect_type",
    [
        (0.0421, "cost_usd"),  # concrete positive USD → tagged
        (None, None),  # header missing → no tag (keeps gap observable)
        (0.0, None),  # zero → no tag (wouldn't bill anything anyway)
    ],
)
def test_perplexity_record_openrouter_cost_tags_only_on_concrete_value(
    openrouter_cost, expect_type
):
    """Drive PerplexityBlock._record_openrouter_cost directly to verify the
    None/0 guard. A regression that tags cost_usd unconditionally would
    silently floor the user's bill to 0 via the resolver — this test
    would catch it.
    """
    from backend.blocks.perplexity import PerplexityBlock

    block = PerplexityBlock()
    with patch(
        "backend.blocks.perplexity.extract_openrouter_cost",
        return_value=openrouter_cost,
    ):
        block._record_openrouter_cost(response=object())

    assert block.execution_stats.provider_cost == openrouter_cost
    assert block.execution_stats.provider_cost_type == expect_type


# -------- Codex COST_USD registration --------


def test_codex_registered_as_cost_usd_150():
    from backend.blocks.codex import CodeGenerationBlock

    entries = BLOCK_COSTS[CodeGenerationBlock]
    assert len(entries) == 1
    entry = entries[0]
    assert entry.cost_type == BlockCostType.COST_USD
    assert entry.cost_amount == 150


@pytest.mark.parametrize(
    "input_tokens, output_tokens, expected_usd",
    [
        # GPT-5.1-Codex: $1.25 / 1M input, $10 / 1M output.
        (1_000_000, 0, 1.25),
        (0, 1_000_000, 10.0),
        (100_000, 10_000, 0.225),  # 0.125 + 0.100
        (0, 0, 0.0),
    ],
)
def test_codex_computes_provider_cost_usd_from_token_counts(
    input_tokens, output_tokens, expected_usd
):
    """Drive CodeGenerationBlock._compute_token_usd directly. A regression
    to the wrong rate constants (e.g. swapping the $1.25 input rate for
    GPT-4o's $2.50) would fail this test.
    """
    from backend.blocks.codex import CodeGenerationBlock

    assert CodeGenerationBlock._compute_token_usd(
        input_tokens, output_tokens
    ) == pytest.approx(expected_usd)
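The parametrized cases above encode a simple per-token linear rate. A self-contained sketch of that arithmetic, using the rates stated in the test comment ($1.25 per 1M input tokens, $10 per 1M output tokens); the function name is illustrative, not the repo's `_compute_token_usd` itself.

```python
# Sketch of the per-token USD arithmetic the test cases encode.
# Rates taken from the test comment; names here are hypothetical.
INPUT_USD_PER_MTOK = 1.25
OUTPUT_USD_PER_MTOK = 10.0


def token_cost_usd(input_tokens: int, output_tokens: int) -> float:
    """Linear blend of input and output token counts into a USD figure."""
    return (
        input_tokens / 1_000_000 * INPUT_USD_PER_MTOK
        + output_tokens / 1_000_000 * OUTPUT_USD_PER_MTOK
    )
```

E.g. 100k input + 10k output tokens comes to 0.125 + 0.100 = 0.225 USD, matching the third parametrized case.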


# -------- ClaudeCode COST_USD registration sanity (already tested in claude_code_cost_test.py) --------


# -------- Perplexity COST_USD registration for all 3 tiers --------


def test_perplexity_sonar_all_tiers_registered_as_cost_usd_150():
    from backend.blocks.perplexity import PerplexityBlock

    entries = BLOCK_COSTS[PerplexityBlock]
    # 3 tiers (SONAR, SONAR_PRO, SONAR_DEEP_RESEARCH) all COST_USD 150.
    assert len(entries) == 3
    for entry in entries:
        assert entry.cost_type == BlockCostType.COST_USD
        assert entry.cost_amount == 150


# -------- Narration COST_USD registration --------


def test_narration_registered_as_cost_usd_150():
    from backend.blocks.video.narration import VideoNarrationBlock

    entries = BLOCK_COSTS[VideoNarrationBlock]
    assert len(entries) == 1
    assert entries[0].cost_type == BlockCostType.COST_USD
    assert entries[0].cost_amount == 150


# -------- Pinecone registrations --------


def test_pinecone_registrations():
    from backend.blocks.pinecone import (
        PineconeInitBlock,
        PineconeInsertBlock,
        PineconeQueryBlock,
    )

    assert BLOCK_COSTS[PineconeInitBlock][0].cost_type == BlockCostType.RUN
    assert BLOCK_COSTS[PineconeQueryBlock][0].cost_type == BlockCostType.RUN
    # Insert scales with item count.
    assert BLOCK_COSTS[PineconeInsertBlock][0].cost_type == BlockCostType.ITEMS
    assert BLOCK_COSTS[PineconeInsertBlock][0].cost_amount == 1
@@ -7,6 +7,7 @@ from backend.sdk import BlockCostType, ProviderBuilder
# Build the DataForSEO provider with username/password authentication
dataforseo = (
    ProviderBuilder("dataforseo")
    .with_description("SEO and SERP data")
    .with_user_password(
        username_env_var="DATAFORSEO_USERNAME",
        password_env_var="DATAFORSEO_PASSWORD",
||||
autogpt_platform/backend/backend/blocks/discord/_config.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""Provider registration for Discord — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

discord = (
    ProviderBuilder("discord")
    .with_description("Messages, channels, and servers")
    .with_supported_auth_types("api_key", "oauth2")
    .build()
)

@@ -0,0 +1,10 @@
"""Provider registration for ElevenLabs — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

elevenlabs = (
    ProviderBuilder("elevenlabs")
    .with_description("Realistic AI voice synthesis")
    .with_supported_auth_types("api_key")
    .build()
)

@@ -0,0 +1,10 @@
"""Provider registration for Enrichlayer — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

enrichlayer = (
    ProviderBuilder("enrichlayer")
    .with_description("Enrich leads with company data")
    .with_supported_auth_types("api_key")
    .build()
)
@@ -9,6 +9,7 @@ from ._webhook import ExaWebhookManager
# Configure the Exa provider once for all blocks
exa = (
    ProviderBuilder("exa")
    .with_description("Neural web search")
    .with_api_key("EXA_API_KEY", "Exa API Key")
    .with_webhook_manager(ExaWebhookManager)
    # Exa returns `cost_dollars.total` on every response and ExaSearchBlock

@@ -17,6 +17,7 @@ from backend.sdk import (
)

from ._config import exa
from .helpers import merge_exa_cost


class AnswerCitation(BaseModel):

@@ -111,3 +112,7 @@ class ExaAnswerBlock(Block):
        yield "citations", citations
        for citation in citations:
            yield "citation", citation

        # Current SDK AnswerResponse dataclass omits cost_dollars; helper
        # no-ops today, but keeps billing wired when exa_py adds the field.
        merge_exa_cost(self, response)

@@ -9,7 +9,6 @@ from typing import Union

from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,

@@ -23,6 +22,7 @@ from backend.sdk import (
)

from ._config import exa
from .helpers import merge_exa_cost


class CodeContextResponse(BaseModel):

@@ -118,9 +118,5 @@ class ExaCodeContextBlock(Block):
        yield "search_time", context.search_time
        yield "output_tokens", context.output_tokens

        # Parse cost_dollars (API returns as string, e.g. "0.005")
        try:
            cost_usd = float(context.cost_dollars)
            self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
        except (ValueError, TypeError):
            pass
        # API returns costDollars as a bare numeric string like "0.005".
        merge_exa_cost(self, data)

@@ -4,7 +4,6 @@ from typing import Optional

from exa_py import AsyncExa
from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,

@@ -24,6 +23,7 @@ from .helpers import (
    HighlightSettings,
    LivecrawlTypes,
    SummarySettings,
    merge_exa_cost,
)


@@ -224,6 +224,4 @@ class ExaContentsBlock(Block):

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars
            self.merge_stats(
                NodeExecutionStats(provider_cost=response.cost_dollars.total)
            )
        merge_exa_cost(self, response)
@@ -143,7 +143,9 @@ class TestExaContentsCostTracking:
        mock_exa_cls.return_value = mock_exa

        async for _ in block.run(
            block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
            block.Input(
                urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
            ),  # type: ignore[arg-type]
            credentials=TEST_CREDENTIALS,
        ):
            pass

@@ -172,7 +174,9 @@ class TestExaContentsCostTracking:
        mock_exa_cls.return_value = mock_exa

        async for _ in block.run(
            block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
            block.Input(
                urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
            ),  # type: ignore[arg-type]
            credentials=TEST_CREDENTIALS,
        ):
            pass

@@ -201,7 +205,9 @@ class TestExaContentsCostTracking:
        mock_exa_cls.return_value = mock_exa

        async for _ in block.run(
            block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
            block.Input(
                urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT
            ),  # type: ignore[arg-type]
            credentials=TEST_CREDENTIALS,
        ):
            pass

@@ -297,7 +303,9 @@ class TestExaSimilarCostTracking:
        mock_exa_cls.return_value = mock_exa

        async for _ in block.run(
            block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
            block.Input(
                url="https://example.com", credentials=TEST_CREDENTIALS_INPUT
            ),  # type: ignore[arg-type]
            credentials=TEST_CREDENTIALS,
        ):
            pass

@@ -326,7 +334,9 @@ class TestExaSimilarCostTracking:
        mock_exa_cls.return_value = mock_exa

        async for _ in block.run(
            block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
            block.Input(
                url="https://example.com", credentials=TEST_CREDENTIALS_INPUT
            ),  # type: ignore[arg-type]
            credentials=TEST_CREDENTIALS,
        ):
            pass
@@ -1,7 +1,8 @@
from enum import Enum
from typing import Any, Dict, Literal, Optional, Union

from backend.sdk import BaseModel, MediaFileType, SchemaField
from backend.data.model import NodeExecutionStats
from backend.sdk import BaseModel, Block, MediaFileType, SchemaField


class LivecrawlTypes(str, Enum):

@@ -319,7 +320,7 @@ class CostDollars(BaseModel):

# Helper functions for payload processing
def process_text_field(
    text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
    text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None],
) -> Optional[Union[bool, Dict[str, Any]]]:
    """Process text field for API payload."""
    if text is None:

@@ -400,7 +401,7 @@ def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str,

def process_context_field(
    context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
    context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None],
) -> Optional[Union[bool, Dict[str, int]]]:
    """Process context field for API payload."""
    if context is None:

@@ -448,3 +449,65 @@ def add_optional_fields(
            payload[api_field] = value.value
        else:
            payload[api_field] = value


def extract_exa_cost_usd(response: Any) -> Optional[float]:
    """Return ``cost_dollars.total`` (USD) from an Exa SDK response, or None.

    Handles dataclass/pydantic responses (``response.cost_dollars.total``),
    dicts with camelCase keys (``response["costDollars"]["total"]``), dicts
    with snake_case keys, and bare numeric strings. Returns None whenever the
    shape is missing cost info — the caller then skips merge_stats.
    """
    if response is None:
        return None

    # Dataclass / pydantic: response.cost_dollars
    cost_obj = getattr(response, "cost_dollars", None)

    # Dict payloads: try both camelCase and snake_case
    if cost_obj is None and isinstance(response, dict):
        cost_obj = response.get("costDollars") or response.get("cost_dollars")

    if cost_obj is None:
        return None

    # Already a scalar (code_context endpoint returns a string)
    if isinstance(cost_obj, (int, float)):
        return max(0.0, float(cost_obj))
    if isinstance(cost_obj, str):
        try:
            return max(0.0, float(cost_obj))
        except ValueError:
            return None

    # Nested object/dict: grab the `total` field
    total = getattr(cost_obj, "total", None)
    if total is None and isinstance(cost_obj, dict):
        total = cost_obj.get("total")

    if total is None:
        return None

    try:
        return max(0.0, float(total))
    except (TypeError, ValueError):
        return None


def merge_exa_cost(block: Block, response: Any) -> None:
    """Pull ``cost_dollars.total`` off an Exa response and merge it into stats.

    No-op when the response shape has no cost info (e.g. webset CRUD where
    the SDK does not expose per-call pricing) — emission happens only when
    Exa actually reports a USD amount.
    """
    cost_usd = extract_exa_cost_usd(response)
    if cost_usd is None:
        return
    block.merge_stats(
        NodeExecutionStats(
            provider_cost=cost_usd,
            provider_cost_type="cost_usd",
        )
    )
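The shape-handling precedence in the helper (attribute access first, then camelCase/snake_case dict keys, then scalar vs. nested `total`, with negatives clamped to 0) can be exercised as a standalone sketch. This mirrors the logic for illustration and does not import from the repo; the function name is hypothetical.

```python
# Standalone mirror of the extract logic above, for illustration only.
from types import SimpleNamespace
from typing import Any, Optional


def extract_cost_usd(response: Any) -> Optional[float]:
    """Attribute first, then dict keys, then scalar-vs-nested `total`."""
    if response is None:
        return None
    cost = getattr(response, "cost_dollars", None)
    if cost is None and isinstance(response, dict):
        cost = response.get("costDollars") or response.get("cost_dollars")
    if cost is None:
        return None
    # Scalar shapes: bare number or numeric string.
    if isinstance(cost, (int, float)):
        return max(0.0, float(cost))
    if isinstance(cost, str):
        try:
            return max(0.0, float(cost))
        except ValueError:
            return None
    # Nested shapes: object or dict carrying `total`.
    total = getattr(cost, "total", None)
    if total is None and isinstance(cost, dict):
        total = cost.get("total")
    if total is None:
        return None
    try:
        return max(0.0, float(total))
    except (TypeError, ValueError):
        return None
```

For example, `extract_cost_usd({"costDollars": {"total": 0.10}})` and `extract_cost_usd(SimpleNamespace(cost_dollars="0.005"))` both resolve, while an empty dict yields `None` so the caller can skip emission.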

autogpt_platform/backend/backend/blocks/exa/helpers_test.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""Unit tests for exa/helpers cost-extraction + merge helpers."""

from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest

from backend.blocks.exa.helpers import extract_exa_cost_usd, merge_exa_cost
from backend.data.model import NodeExecutionStats


@pytest.mark.parametrize(
    "response, expected",
    [
        # Dataclass / SimpleNamespace with cost_dollars.total
        (SimpleNamespace(cost_dollars=SimpleNamespace(total=0.05)), 0.05),
        # Dict camelCase
        ({"costDollars": {"total": 0.10}}, 0.10),
        # Dict snake_case
        ({"cost_dollars": {"total": 0.07}}, 0.07),
        # code_context endpoint shape: plain numeric string
        (SimpleNamespace(cost_dollars="0.005"), 0.005),
        # Scalar float on cost_dollars directly
        (SimpleNamespace(cost_dollars=0.02), 0.02),
        # Scalar int on cost_dollars
        (SimpleNamespace(cost_dollars=3), 3.0),
        # Missing cost info — returns None
        ({}, None),
        (SimpleNamespace(other="foo"), None),
        (None, None),
        # Nested total=None
        (SimpleNamespace(cost_dollars=SimpleNamespace(total=None)), None),
        # Invalid numeric string
        (SimpleNamespace(cost_dollars="not-a-number"), None),
        # Negative values clamp to 0
        (SimpleNamespace(cost_dollars=SimpleNamespace(total=-1.0)), 0.0),
    ],
)
def test_extract_exa_cost_usd_handles_all_shapes(response, expected):
    assert extract_exa_cost_usd(response) == expected


def test_merge_exa_cost_emits_stats_when_cost_present():
    block = MagicMock()
    response = SimpleNamespace(cost_dollars=SimpleNamespace(total=0.0421))
    merge_exa_cost(block, response)

    block.merge_stats.assert_called_once()
    stats: NodeExecutionStats = block.merge_stats.call_args.args[0]
    assert stats.provider_cost == pytest.approx(0.0421)
    assert stats.provider_cost_type == "cost_usd"


def test_merge_exa_cost_noops_when_no_cost():
    """Webset CRUD endpoints don't surface cost_dollars today — the helper
    must silently skip instead of emitting a 0-cost telemetry record."""
    block = MagicMock()
    merge_exa_cost(block, SimpleNamespace(other_field="nothing"))
    block.merge_stats.assert_not_called()


def test_merge_exa_cost_noops_when_response_is_none():
    block = MagicMock()
    merge_exa_cost(block, None)
    block.merge_stats.assert_not_called()
@@ -12,7 +12,6 @@ from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,

@@ -26,6 +25,7 @@ from backend.sdk import (
)

from ._config import exa
from .helpers import merge_exa_cost


class ResearchModel(str, Enum):

@@ -233,11 +233,7 @@ class ExaCreateResearchBlock(Block):

                if research.cost_dollars:
                    yield "cost_total", research.cost_dollars.total
                    self.merge_stats(
                        NodeExecutionStats(
                            provider_cost=research.cost_dollars.total
                        )
                    )
                merge_exa_cost(self, research)
                return

            await asyncio.sleep(check_interval)

@@ -352,9 +348,7 @@ class ExaGetResearchBlock(Block):
            yield "cost_searches", research.cost_dollars.num_searches
            yield "cost_pages", research.cost_dollars.num_pages
            yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
            self.merge_stats(
                NodeExecutionStats(provider_cost=research.cost_dollars.total)
            )
        merge_exa_cost(self, research)

        yield "error_message", research.error

@@ -441,9 +435,7 @@ class ExaWaitForResearchBlock(Block):

                if research.cost_dollars:
                    yield "cost_total", research.cost_dollars.total
                    self.merge_stats(
                        NodeExecutionStats(provider_cost=research.cost_dollars.total)
                    )
                merge_exa_cost(self, research)

                return

@@ -4,7 +4,6 @@ from typing import Optional

from exa_py import AsyncExa

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,

@@ -21,6 +20,7 @@ from .helpers import (
    ContentSettings,
    CostDollars,
    ExaSearchResults,
    merge_exa_cost,
    process_contents_settings,
)

@@ -207,6 +207,4 @@ class ExaSearchBlock(Block):

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars
            self.merge_stats(
                NodeExecutionStats(provider_cost=response.cost_dollars.total)
            )
        merge_exa_cost(self, response)

@@ -3,7 +3,6 @@ from typing import Optional

from exa_py import AsyncExa

from backend.data.model import NodeExecutionStats
from backend.sdk import (
    APIKeyCredentials,
    Block,

@@ -20,6 +19,7 @@ from .helpers import (
    ContentSettings,
    CostDollars,
    ExaSearchResults,
    merge_exa_cost,
    process_contents_settings,
)

@@ -168,6 +168,4 @@ class ExaFindSimilarBlock(Block):

        if response.cost_dollars:
            yield "cost_dollars", response.cost_dollars
            self.merge_stats(
                NodeExecutionStats(provider_cost=response.cost_dollars.total)
            )
        merge_exa_cost(self, response)
@@ -39,6 +39,7 @@ from backend.sdk import (
)

from ._config import exa
from .helpers import merge_exa_cost


class SearchEntityType(str, Enum):

@@ -394,6 +395,7 @@ class ExaCreateWebsetBlock(Block):
                metadata=input_data.metadata,
            )
        )
        merge_exa_cost(self, webset)

        webset_result = Webset.model_validate(webset.model_dump(by_alias=True))

@@ -404,6 +406,7 @@ class ExaCreateWebsetBlock(Block):
                timeout=input_data.polling_timeout,
                poll_interval=5,
            )
            merge_exa_cost(self, final_webset)
            completion_time = time.time() - start_time

            item_count = 0

@@ -479,6 +482,7 @@ class ExaCreateOrFindWebsetBlock(Block):

        try:
            webset = await aexa.websets.get(id=input_data.external_id)
            merge_exa_cost(self, webset)
            webset_result = Webset.model_validate(webset.model_dump(by_alias=True))

            yield "webset", webset_result

@@ -501,6 +505,7 @@ class ExaCreateOrFindWebsetBlock(Block):
                    metadata=input_data.metadata,
                )
            )
            merge_exa_cost(self, webset)

            webset_result = Webset.model_validate(webset.model_dump(by_alias=True))

@@ -555,6 +560,7 @@ class ExaUpdateWebsetBlock(Block):
            payload["metadata"] = input_data.metadata

        sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
        merge_exa_cost(self, sdk_webset)

        status_str = (
            sdk_webset.status.value

@@ -566,8 +572,9 @@ class ExaUpdateWebsetBlock(Block):
        yield "status", status_str
        yield "external_id", sdk_webset.external_id
        yield "metadata", sdk_webset.metadata or {}
        yield "updated_at", (
            sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""
        yield (
            "updated_at",
            (sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""),
        )

@@ -621,6 +628,7 @@ class ExaListWebsetsBlock(Block):
            cursor=input_data.cursor,
            limit=input_data.limit,
        )
        merge_exa_cost(self, response)

        websets_data = [
            w.model_dump(by_alias=True, exclude_none=True) for w in response.data

@@ -679,6 +687,7 @@ class ExaGetWebsetBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_webset = await aexa.websets.get(id=input_data.webset_id)
        merge_exa_cost(self, sdk_webset)

        status_str = (
            sdk_webset.status.value

@@ -706,11 +715,13 @@ class ExaGetWebsetBlock(Block):
        yield "enrichments", enrichments_data
        yield "monitors", monitors_data
        yield "metadata", sdk_webset.metadata or {}
        yield "created_at", (
            sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""
        yield (
            "created_at",
            (sdk_webset.created_at.isoformat() if sdk_webset.created_at else ""),
        )
        yield "updated_at", (
            sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""
        yield (
            "updated_at",
            (sdk_webset.updated_at.isoformat() if sdk_webset.updated_at else ""),
        )

@@ -749,6 +760,7 @@ class ExaDeleteWebsetBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
        merge_exa_cost(self, deleted_webset)

        status_str = (
            deleted_webset.status.value

@@ -799,6 +811,7 @@ class ExaCancelWebsetBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
        merge_exa_cost(self, canceled_webset)

        status_str = (
            canceled_webset.status.value

@@ -969,6 +982,7 @@ class ExaPreviewWebsetBlock(Block):
            payload["entity"] = entity

        sdk_preview = await aexa.websets.preview(params=payload)
        merge_exa_cost(self, sdk_preview)

        preview = PreviewWebsetModel.from_sdk(sdk_preview)

@@ -1052,6 +1066,7 @@ class ExaWebsetStatusBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        webset = await aexa.websets.get(id=input_data.webset_id)
        merge_exa_cost(self, webset)

        status = (
            webset.status.value

@@ -1186,6 +1201,7 @@ class ExaWebsetSummaryBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        webset = await aexa.websets.get(id=input_data.webset_id)
        merge_exa_cost(self, webset)

        # Extract basic info
        webset_id = webset.id

@@ -1214,6 +1230,7 @@ class ExaWebsetSummaryBlock(Block):
            items_response = await aexa.websets.items.list(
                webset_id=input_data.webset_id, limit=input_data.sample_size
            )
            merge_exa_cost(self, items_response)
            sample_items_data = [
                item.model_dump(by_alias=True, exclude_none=True)
                for item in items_response.data

@@ -1363,6 +1380,7 @@ class ExaWebsetReadyCheckBlock(Block):

        # Get webset details
        webset = await aexa.websets.get(id=input_data.webset_id)
        merge_exa_cost(self, webset)

        status = (
            webset.status.value
@@ -25,6 +25,7 @@ from backend.sdk import (
)

from ._config import exa
from .helpers import merge_exa_cost


# Mirrored model for stability

@@ -205,6 +206,7 @@ class ExaCreateEnrichmentBlock(Block):
        sdk_enrichment = await aexa.websets.enrichments.create(
            webset_id=input_data.webset_id, params=payload
        )
        merge_exa_cost(self, sdk_enrichment)

        enrichment_id = sdk_enrichment.id
        status = (

@@ -226,6 +228,7 @@ class ExaCreateEnrichmentBlock(Block):
                current_enrich = await aexa.websets.enrichments.get(
                    webset_id=input_data.webset_id, id=enrichment_id
                )
                merge_exa_cost(self, current_enrich)
                current_status = (
                    current_enrich.status.value
                    if hasattr(current_enrich.status, "value")

@@ -235,6 +238,7 @@ class ExaCreateEnrichmentBlock(Block):
                if current_status in ["completed", "failed", "cancelled"]:
                    # Estimate items from webset searches
                    webset = await aexa.websets.get(id=input_data.webset_id)
                    merge_exa_cost(self, webset)
                    if webset.searches:
                        for search in webset.searches:
                            if search.progress:

@@ -332,6 +336,7 @@ class ExaGetEnrichmentBlock(Block):
        sdk_enrichment = await aexa.websets.enrichments.get(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )
        merge_exa_cost(self, sdk_enrichment)

        enrichment = WebsetEnrichmentModel.from_sdk(sdk_enrichment)

@@ -425,6 +430,7 @@ class ExaUpdateEnrichmentBlock(Block):
        try:
            response = await Requests().patch(url, headers=headers, json=payload)
            data = response.json()
            # PATCH /websets/{id}/enrichments/{id} doesn't return costDollars.

            yield "enrichment_id", data.get("id", "")
            yield "status", data.get("status", "")

@@ -477,6 +483,7 @@ class ExaDeleteEnrichmentBlock(Block):
        deleted_enrichment = await aexa.websets.enrichments.delete(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )
        merge_exa_cost(self, deleted_enrichment)

        yield "enrichment_id", deleted_enrichment.id
        yield "success", "true"

@@ -528,12 +535,14 @@ class ExaCancelEnrichmentBlock(Block):
        canceled_enrichment = await aexa.websets.enrichments.cancel(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )
        merge_exa_cost(self, canceled_enrichment)

        # Try to estimate how many items were enriched before cancellation
        items_enriched = 0
        items_response = await aexa.websets.items.list(
            webset_id=input_data.webset_id, limit=100
        )
        merge_exa_cost(self, items_response)

        for sdk_item in items_response.data:
            # Check if this enrichment is present
@@ -29,6 +29,7 @@ from backend.sdk import (

from ._config import exa
from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from .helpers import merge_exa_cost


# Mirrored model for stability - don't use SDK types directly in block outputs
@@ -297,6 +298,7 @@ class ExaCreateImportBlock(Block):
        sdk_import = await aexa.websets.imports.create(
            params=payload, csv_data=input_data.csv_data
        )
        merge_exa_cost(self, sdk_import)

        import_obj = ImportModel.from_sdk(sdk_import)
@@ -361,6 +363,7 @@ class ExaGetImportBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
        merge_exa_cost(self, sdk_import)

        import_obj = ImportModel.from_sdk(sdk_import)
@@ -430,6 +433,7 @@ class ExaListImportsBlock(Block):
            cursor=input_data.cursor,
            limit=input_data.limit,
        )
        merge_exa_cost(self, response)

        # Convert SDK imports to our stable models
        imports = [ImportModel.from_sdk(i) for i in response.data]
@@ -477,6 +481,7 @@ class ExaDeleteImportBlock(Block):
        deleted_import = await aexa.websets.imports.delete(
            import_id=input_data.import_id
        )
        merge_exa_cost(self, deleted_import)

        yield "import_id", deleted_import.id
        yield "success", "true"
@@ -599,7 +604,7 @@ class ExaExportWebsetBlock(Block):
        try:
            all_items = []

            # Use SDK's list_all iterator to fetch items
            # list_all paginates internally; cost_dollars is not surfaced per-page
            item_iterator = aexa.websets.items.list_all(
                webset_id=input_data.webset_id, limit=input_data.max_items
            )
@@ -30,6 +30,7 @@ from backend.sdk import (
)

from ._config import exa
from .helpers import merge_exa_cost


# Mirrored model for enrichment results
@@ -181,6 +182,7 @@ class ExaGetWebsetItemBlock(Block):
        sdk_item = await aexa.websets.items.get(
            webset_id=input_data.webset_id, id=input_data.item_id
        )
        merge_exa_cost(self, sdk_item)

        item = WebsetItemModel.from_sdk(sdk_item)
@@ -293,6 +295,7 @@ class ExaListWebsetItemsBlock(Block):
            cursor=input_data.cursor,
            limit=input_data.limit,
        )
        merge_exa_cost(self, response)

        items = [WebsetItemModel.from_sdk(item) for item in response.data]
@@ -343,6 +346,7 @@ class ExaDeleteWebsetItemBlock(Block):
        deleted_item = await aexa.websets.items.delete(
            webset_id=input_data.webset_id, id=input_data.item_id
        )
        merge_exa_cost(self, deleted_item)

        yield "item_id", deleted_item.id
        yield "success", "true"
@@ -404,6 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        all_items: List[WebsetItemModel] = []
        # list_all paginates internally; cost_dollars is not surfaced per-page
        item_iterator = aexa.websets.items.list_all(
            webset_id=input_data.webset_id, limit=input_data.max_items
        )
@@ -476,6 +481,7 @@ class ExaWebsetItemsSummaryBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        webset = await aexa.websets.get(id=input_data.webset_id)
        merge_exa_cost(self, webset)

        entity_type = "unknown"
        if webset.searches:
@@ -498,6 +504,7 @@ class ExaWebsetItemsSummaryBlock(Block):
        items_response = await aexa.websets.items.list(
            webset_id=input_data.webset_id, limit=input_data.sample_size
        )
        merge_exa_cost(self, items_response)
        # Convert to our stable models
        sample_items = [
            WebsetItemModel.from_sdk(item) for item in items_response.data
@@ -574,6 +581,7 @@ class ExaGetNewItemsBlock(Block):
            cursor=input_data.since_cursor,
            limit=input_data.max_items,
        )
        merge_exa_cost(self, response)

        # Convert SDK items to our stable models
        new_items = [WebsetItemModel.from_sdk(item) for item in response.data]
@@ -25,6 +25,7 @@ from backend.sdk import (

from ._config import exa
from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from .helpers import merge_exa_cost


# Mirrored model for stability - don't use SDK types directly in block outputs
@@ -321,6 +322,7 @@ class ExaCreateMonitorBlock(Block):
            payload["metadata"] = input_data.metadata

        sdk_monitor = await aexa.websets.monitors.create(params=payload)
        merge_exa_cost(self, sdk_monitor)

        monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -385,6 +387,7 @@ class ExaGetMonitorBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
        merge_exa_cost(self, sdk_monitor)

        monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -479,6 +482,7 @@ class ExaUpdateMonitorBlock(Block):
        sdk_monitor = await aexa.websets.monitors.update(
            monitor_id=input_data.monitor_id, params=payload
        )
        merge_exa_cost(self, sdk_monitor)

        # Convert to our stable model
        monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -525,6 +529,7 @@ class ExaDeleteMonitorBlock(Block):
        deleted_monitor = await aexa.websets.monitors.delete(
            monitor_id=input_data.monitor_id
        )
        merge_exa_cost(self, deleted_monitor)

        yield "monitor_id", deleted_monitor.id
        yield "success", "true"
@@ -586,6 +591,7 @@ class ExaListMonitorsBlock(Block):
            limit=input_data.limit,
            webset_id=input_data.webset_id,
        )
        merge_exa_cost(self, response)

        # Convert SDK monitors to our stable models
        monitors = [MonitorModel.from_sdk(m) for m in response.data]
@@ -25,6 +25,7 @@ from backend.sdk import (
)

from ._config import exa
from .helpers import merge_exa_cost

# Import WebsetItemModel for use in enrichment samples
# This is safe as websets_items doesn't import from websets_polling
@@ -126,6 +127,7 @@ class ExaWaitForWebsetBlock(Block):
            timeout=input_data.timeout,
            poll_interval=input_data.check_interval,
        )
        merge_exa_cost(self, final_webset)

        elapsed = time.time() - start_time
@@ -165,6 +167,7 @@ class ExaWaitForWebsetBlock(Block):
        while time.time() - start_time < input_data.timeout:
            # Get current webset status
            webset = await aexa.websets.get(id=input_data.webset_id)
            merge_exa_cost(self, webset)
            current_status = (
                webset.status.value
                if hasattr(webset.status, "value")
@@ -210,6 +213,7 @@ class ExaWaitForWebsetBlock(Block):
        # Timeout reached
        elapsed = time.time() - start_time
        webset = await aexa.websets.get(id=input_data.webset_id)
        merge_exa_cost(self, webset)
        final_status = (
            webset.status.value
            if hasattr(webset.status, "value")
@@ -348,6 +352,7 @@ class ExaWaitForSearchBlock(Block):
        search = await aexa.websets.searches.get(
            webset_id=input_data.webset_id, id=input_data.search_id
        )
        merge_exa_cost(self, search)

        # Extract status
        status = (
@@ -404,6 +409,7 @@ class ExaWaitForSearchBlock(Block):
        search = await aexa.websets.searches.get(
            webset_id=input_data.webset_id, id=input_data.search_id
        )
        merge_exa_cost(self, search)
        final_status = (
            search.status.value
            if hasattr(search.status, "value")
@@ -506,6 +512,7 @@ class ExaWaitForEnrichmentBlock(Block):
        enrichment = await aexa.websets.enrichments.get(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )
        merge_exa_cost(self, enrichment)

        # Extract status
        status = (
@@ -523,16 +530,20 @@ class ExaWaitForEnrichmentBlock(Block):
        items_enriched = 0

        if input_data.sample_results and status == "completed":
            sample_data, items_enriched = (
                await self._get_sample_enrichments(
                    input_data.webset_id, input_data.enrichment_id, aexa
                )
            (
                sample_data,
                items_enriched,
            ) = await self._get_sample_enrichments(
                input_data.webset_id, input_data.enrichment_id, aexa
            )

        yield "enrichment_id", input_data.enrichment_id
        yield "final_status", status
        yield "items_enriched", items_enriched
        yield "enrichment_title", enrichment.title or enrichment.description or ""
        yield (
            "enrichment_title",
            enrichment.title or enrichment.description or "",
        )
        yield "elapsed_time", elapsed
        if input_data.sample_results:
            yield "sample_data", sample_data
@@ -551,6 +562,7 @@ class ExaWaitForEnrichmentBlock(Block):
        enrichment = await aexa.websets.enrichments.get(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )
        merge_exa_cost(self, enrichment)
        final_status = (
            enrichment.status.value
            if hasattr(enrichment.status, "value")
@@ -576,6 +588,7 @@ class ExaWaitForEnrichmentBlock(Block):
        """Get sample enriched data and count."""
        # Get a few items to see enrichment results using SDK
        response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
        merge_exa_cost(self, response)

        sample_data: list[SampleEnrichmentModel] = []
        enriched_count = 0
@@ -24,6 +24,7 @@ from backend.sdk import (
)

from ._config import exa
from .helpers import merge_exa_cost


# Mirrored model for stability
@@ -320,6 +321,7 @@ class ExaCreateWebsetSearchBlock(Block):
        sdk_search = await aexa.websets.searches.create(
            webset_id=input_data.webset_id, params=payload
        )
        merge_exa_cost(self, sdk_search)

        search_id = sdk_search.id
        status = (
@@ -353,6 +355,7 @@ class ExaCreateWebsetSearchBlock(Block):
            current_search = await aexa.websets.searches.get(
                webset_id=input_data.webset_id, id=search_id
            )
            merge_exa_cost(self, current_search)
            current_status = (
                current_search.status.value
                if hasattr(current_search.status, "value")
@@ -445,6 +448,7 @@ class ExaGetWebsetSearchBlock(Block):
        sdk_search = await aexa.websets.searches.get(
            webset_id=input_data.webset_id, id=input_data.search_id
        )
        merge_exa_cost(self, sdk_search)

        search = WebsetSearchModel.from_sdk(sdk_search)
@@ -526,6 +530,7 @@ class ExaCancelWebsetSearchBlock(Block):
        canceled_search = await aexa.websets.searches.cancel(
            webset_id=input_data.webset_id, id=input_data.search_id
        )
        merge_exa_cost(self, canceled_search)

        # Extract items found before cancellation
        items_found = 0
@@ -605,6 +610,7 @@ class ExaFindOrCreateSearchBlock(Block):

        # Get webset to check existing searches
        webset = await aexa.websets.get(id=input_data.webset_id)
        merge_exa_cost(self, webset)

        # Look for existing search with same query
        existing_search = None
@@ -639,6 +645,7 @@ class ExaFindOrCreateSearchBlock(Block):
        sdk_search = await aexa.websets.searches.create(
            webset_id=input_data.webset_id, params=payload
        )
        merge_exa_cost(self, sdk_search)

        search = WebsetSearchModel.from_sdk(sdk_search)
autogpt_platform/backend/backend/blocks/fal/_config.py (new file, +10)
@@ -0,0 +1,10 @@
"""Provider registration for fal — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

fal = (
    ProviderBuilder("fal")
    .with_description("Hosted model inference")
    .with_supported_auth_types("api_key")
    .build()
)
@@ -8,6 +8,7 @@ from backend.sdk import BlockCostType, ProviderBuilder
# — roughly matches our existing per-call tier for single-page scrape.
firecrawl = (
    ProviderBuilder("firecrawl")
    .with_description("Web scraping and crawling")
    .with_api_key("FIRECRAWL_API_KEY", "Firecrawl API Key")
    .with_base_cost(1000, BlockCostType.COST_USD)
    .build()
@@ -14,6 +14,7 @@ from ._webhook import GenericWebhooksManager, GenericWebhookType

generic_webhook = (
    ProviderBuilder("generic_webhook")
    .with_description("Inbound webhook trigger")
    .with_webhook_manager(GenericWebhooksManager)
    .build()
)
autogpt_platform/backend/backend/blocks/github/_config.py (new file, +10)
@@ -0,0 +1,10 @@
"""Provider registration for GitHub — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

github = (
    ProviderBuilder("github")
    .with_description("Issues, pull requests, repositories")
    .with_supported_auth_types("api_key", "oauth2")
    .build()
)
autogpt_platform/backend/backend/blocks/google/_config.py (new file, +10)
@@ -0,0 +1,10 @@
"""Provider registration for Google — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

google = (
    ProviderBuilder("google")
    .with_description("Gmail, Drive, Calendar, Sheets")
    .with_supported_auth_types("oauth2")
    .build()
)
autogpt_platform/backend/backend/blocks/hubspot/_config.py (new file, +10)
@@ -0,0 +1,10 @@
"""Provider registration for HubSpot — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

hubspot = (
    ProviderBuilder("hubspot")
    .with_description("CRM, contacts, and deals")
    .with_supported_auth_types("api_key", "oauth2")
    .build()
)
autogpt_platform/backend/backend/blocks/jina/_config.py (new file, +10)
@@ -0,0 +1,10 @@
"""Provider registration for Jina — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

jina = (
    ProviderBuilder("jina")
    .with_description("Embeddings and reranking")
    .with_supported_auth_types("api_key")
    .build()
)
@@ -39,6 +39,7 @@ class LinearScope(str, Enum):

linear = (
    ProviderBuilder("linear")
    .with_description("Issues and project tracking")
    .with_api_key(env_var_name="LINEAR_API_KEY", title="Linear API Key")
    .with_base_cost(1, BlockCostType.RUN)
    .with_oauth(
@@ -142,6 +142,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
    CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
    CLAUDE_4_6_OPUS = "claude-opus-4-6"
    CLAUDE_4_7_OPUS = "claude-opus-4-7"
    CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
    # AI/ML API models
@@ -331,6 +332,9 @@ MODEL_METADATA = {
    LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
        "anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
    ),  # claude-opus-4-6
    LlmModel.CLAUDE_4_7_OPUS: ModelMetadata(
        "anthropic", 200000, 128000, "Claude Opus 4.7", "Anthropic", "Anthropic", 3
    ),  # claude-opus-4-7
    LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
        "anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
    ),  # claude-sonnet-4-6
@@ -1624,6 +1628,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                        llm_call_count=retry_count + 1,
                        llm_retry_count=retry_count,
                        provider_cost=total_provider_cost,
                        provider_cost_type=(
                            "cost_usd"
                            if total_provider_cost is not None
                            else None
                        ),
                    )
                )
                yield "response", response_obj
@@ -1645,6 +1654,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                        llm_call_count=retry_count + 1,
                        llm_retry_count=retry_count,
                        provider_cost=total_provider_cost,
                        provider_cost_type=(
                            "cost_usd" if total_provider_cost is not None else None
                        ),
                    )
                )
                yield "response", {"response": response_text}
@@ -1679,7 +1691,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
        # All retries exhausted or user-error break: persist accumulated cost so
        # the executor can still charge/report the spend even on failure.
        if total_provider_cost is not None:
            self.merge_stats(NodeExecutionStats(provider_cost=total_provider_cost))
            self.merge_stats(
                NodeExecutionStats(
                    provider_cost=total_provider_cost,
                    provider_cost_type="cost_usd",
                )
            )
        raise RuntimeError(error_feedback_message)

    def response_format_instructions(
autogpt_platform/backend/backend/blocks/mcp/_config.py (new file, +5)
@@ -0,0 +1,5 @@
"""Provider registration for MCP — metadata only."""

from backend.sdk import ProviderBuilder

mcp = ProviderBuilder("mcp").with_description("Model Context Protocol servers").build()
autogpt_platform/backend/backend/blocks/notion/_config.py (new file, +10)
@@ -0,0 +1,10 @@
"""Provider registration for Notion — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

notion = (
    ProviderBuilder("notion")
    .with_description("Pages, databases, and blocks")
    .with_supported_auth_types("oauth2")
    .build()
)
autogpt_platform/backend/backend/blocks/nvidia/_config.py (new file, +10)
@@ -0,0 +1,10 @@
"""Provider registration for Nvidia — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

nvidia = (
    ProviderBuilder("nvidia")
    .with_description("NIM-hosted foundation models")
    .with_supported_auth_types("api_key")
    .build()
)
@@ -250,14 +250,7 @@ class PerplexityBlock(Block):
                self.execution_stats.output_token_count = (
                    response.usage.completion_tokens
                )
            # OpenRouter's ``x-total-cost`` response header carries the real
            # per-request USD cost. Piping it into ``provider_cost`` lets the
            # direct-run ``PlatformCostLog`` flow
            # (``executor.cost_tracking::log_system_credential_cost``) record
            # the actual operator-side spend instead of inferring from tokens.
            # Always overwrite — ``execution_stats`` is instance state, so a
            # response without the header must not reuse a previous run's cost.
            self.execution_stats.provider_cost = extract_openrouter_cost(response)
            self._record_openrouter_cost(response)

            return {"response": response_content, "annotations": annotations or []}

@@ -265,6 +258,17 @@ class PerplexityBlock(Block):
            logger.error(f"Error calling Perplexity: {e}")
            raise

    def _record_openrouter_cost(self, response: Any) -> None:
        """Feed OpenRouter's ``x-total-cost`` USD into execution stats for
        the COST_USD resolver. Tag as ``cost_usd`` only when the value is
        concrete and positive — leaving it unset on None/0 keeps the
        billing gap observable instead of silently floored to 0.
        """
        cost_usd = extract_openrouter_cost(response)
        self.execution_stats.provider_cost = cost_usd
        if cost_usd is not None and cost_usd > 0:
            self.execution_stats.provider_cost_type = "cost_usd"

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
@@ -14,6 +14,7 @@ from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    NodeExecutionStats,
    SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -160,10 +161,13 @@ class PineconeQueryBlock(Block):
            combined_text = "\n\n".join(texts)

            # Return both the raw matches and combined text
            yield "results", {
                "matches": results["matches"],
                "combined_text": combined_text,
            }
            yield (
                "results",
                {
                    "matches": results["matches"],
                    "combined_text": combined_text,
                },
            )
            yield "combined_results", combined_text

        except Exception as e:
@@ -228,6 +232,13 @@ class PineconeInsertBlock(Block):
            )
            idx.upsert(vectors=vectors, namespace=input_data.namespace)

            self.merge_stats(
                NodeExecutionStats(
                    provider_cost=float(len(vectors)),
                    provider_cost_type="items",
                )
            )

            yield "upsert_response", "successfully upserted"

        except Exception as e:
autogpt_platform/backend/backend/blocks/replicate/_config.py (new file, +10)
@@ -0,0 +1,10 @@
"""Provider registration for Replicate — metadata only."""

from backend.sdk import ProviderBuilder

replicate = (
    ProviderBuilder("replicate")
    .with_description("Run and host open-source models")
    .with_supported_auth_types("api_key")
    .build()
)
@@ -16,12 +16,24 @@ from backend.blocks.replicate._auth import (
    TEST_CREDENTIALS_INPUT,
    ReplicateCredentialsInput,
)
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.blocks.replicate._helper import extract_result
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    NodeExecutionStats,
    SchemaField,
)
from backend.util.exceptions import BlockExecutionError, BlockInputError

logger = logging.getLogger(__name__)

# Replicate hardware tier cost — most popular public models (Flux, SDXL,
# Llama 70B etc.) run on Nvidia L40S at $0.001400/sec. Using a single
# conservative mid-tier estimate is much better than a flat RUN charge,
# which under-bills long-running models by 10-500×. Heavier models run
# on A100 at $0.001400/sec; cheaper ones on L4 at $0.000275/sec.
_REPLICATE_USD_PER_SEC = 0.001400


class ReplicateModelBlock(Block):
    """
@@ -138,20 +150,54 @@ class ReplicateModelBlock(Block):
        """
        Run the Replicate model. This method can be mocked for testing.

        Uses predictions.async_create + async_wait instead of async_run so
        we can read ``prediction.metrics.predict_time`` after completion
        and emit it as ``provider_cost`` for the COST_USD resolver.

        Args:
            model_ref: The model reference (e.g., "owner/model-name:version")
            model_inputs: The inputs to pass to the model
            api_key: The Replicate API key as SecretStr

        Returns:
            Tuple of (result, prediction_id)
            Model output (same shape as previous async_run path)
        """
        api_key_str = api_key.get_secret_value()
        client = ReplicateClient(api_token=api_key_str)
        output: ReplicateOutputs = await client.async_run(
            model_ref, input=model_inputs, wait=False
        )  # type: ignore they suck at typing

        result = extract_result(output)
        # Replicate SDK: version-pinned refs use `version=`; unpinned use
        # `model=`. Matches the `owner/name[:version]` contract above.
        if ":" in model_ref:
            model_name, version = model_ref.split(":", 1)
            prediction = await client.predictions.async_create(
                version=version, input=model_inputs
            )
        else:
            prediction = await client.predictions.async_create(
                model=model_ref, input=model_inputs
            )

        return result
        await prediction.async_wait()

        # async_wait returns normally on "failed"/"canceled" — only async_run
        # raises. Without this check we'd bill partial compute time on a
        # failed run and silently yield empty output.
        if prediction.status == "failed":
            raise RuntimeError(
                f"Replicate prediction failed: {prediction.error or 'unknown error'}"
            )
        if prediction.status == "canceled":
            raise RuntimeError("Replicate prediction was canceled")

        if prediction.metrics and prediction.metrics.get("predict_time"):
            predict_time = float(prediction.metrics["predict_time"])
            self.merge_stats(
                NodeExecutionStats(
                    provider_cost=predict_time * _REPLICATE_USD_PER_SEC,
                    provider_cost_type="cost_usd",
                )
            )

        if prediction.output is None:
            raise RuntimeError("Replicate prediction returned no output")
        return extract_result(prediction.output)
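The billing math this change introduces is spelled out in the test comments that follow (a 5-second run at $0.0014/s is $0.007, which rounds up to 2 credits at 150 credits per dollar). A standalone sketch of that conversion; the per-second rate comes from `_REPLICATE_USD_PER_SEC` above, while the 150 credits/$ rate is an assumption taken from the block's COST_USD registration as asserted in the tests:

```python
import math

_REPLICATE_USD_PER_SEC = 0.001400  # L40S tier, per the diff above
_CREDITS_PER_USD = 150  # assumed from the block's COST_USD registration


def replicate_credits(predict_time_s: float) -> int:
    """Convert a prediction's predict_time into billed credits:
    seconds -> USD at the hardware rate, then ceil at the registered
    credits-per-dollar rate (ceiling so fractional cost still charges)."""
    usd = predict_time_s * _REPLICATE_USD_PER_SEC
    return math.ceil(usd * _CREDITS_PER_USD)


# 5 s * $0.0014/s = $0.007; 0.007 * 150 = 1.05 -> ceil -> 2 credits
```

The ceiling means even a sub-cent run charges at least one credit, which is the intended floor for any run that reports a positive `predict_time`.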
@@ -0,0 +1,233 @@
"""Unit tests for ReplicateModelBlock's predictions.async_create billing path.

Verifies the refactored run_model correctly:
1. Uses predictions.async_create (version= vs model= based on ":" in model_ref)
2. Awaits async_wait() for metrics to be populated
3. Reads prediction.metrics["predict_time"] and emits provider_cost/cost_usd
4. Returns extract_result(prediction.output) with the same shape as the old
   async_run path
5. Gracefully skips merge_stats when metrics is missing (protects against a
   silent wallet-free leak on SDK quirks)
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import SecretStr

from backend.blocks._base import BlockCostType
from backend.blocks.replicate.replicate_block import (
    _REPLICATE_USD_PER_SEC,
    ReplicateModelBlock,
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.model import NodeExecutionStats


def test_registered_as_cost_usd_150():
    entries = BLOCK_COSTS[ReplicateModelBlock]
    assert len(entries) == 1
    assert entries[0].cost_type == BlockCostType.COST_USD
    assert entries[0].cost_amount == 150


def test_hardware_rate_constant_in_range():
    """$0.0014/s is Nvidia L40S tier. Sanity-check we haven't accidentally
    shipped a rate that's off by an order of magnitude (e.g. $0.014 would
    10x over-bill every run).
    """
    # Replicate's public hardware tiers: L4 $0.000275, A10G $0.000575,
    # L40S $0.000975, A100 $0.001400, A100-80GB $0.001725. L40S @
    # $0.0014/s covers most popular models with mild over-bill margin.
    assert 0.0005 <= _REPLICATE_USD_PER_SEC <= 0.002


def _make_fake_prediction(output, predict_time=None, status="succeeded", error=None):
    """Build a stand-in for replicate's Prediction with the attrs we touch."""
    pred = MagicMock()
    pred.output = output
    pred.status = status
    pred.error = error
    pred.metrics = {"predict_time": predict_time} if predict_time is not None else None
    pred.async_wait = AsyncMock(return_value=None)
    return pred


@pytest.mark.asyncio
async def test_run_model_uses_version_keyword_when_ref_has_colon():
    """`"owner/name:sha"` → predictions.async_create(version=sha, ...)."""
    block = ReplicateModelBlock()
    prediction = _make_fake_prediction(output="hello", predict_time=3.2)

    client = MagicMock()
    client.predictions.async_create = AsyncMock(return_value=prediction)

    with patch(
        "backend.blocks.replicate.replicate_block.ReplicateClient",
        return_value=client,
    ):
        await block.run_model(
            "owner/model:abc123", {"prompt": "hi"}, SecretStr("fake-key")
        )

    client.predictions.async_create.assert_awaited_once_with(
        version="abc123", input={"prompt": "hi"}
    )


@pytest.mark.asyncio
async def test_run_model_uses_model_keyword_when_ref_is_unpinned():
    """`"owner/name"` (no `:version`) → predictions.async_create(model=ref, ...)."""
    block = ReplicateModelBlock()
    prediction = _make_fake_prediction(output="hello", predict_time=1.0)

    client = MagicMock()
    client.predictions.async_create = AsyncMock(return_value=prediction)

    with patch(
        "backend.blocks.replicate.replicate_block.ReplicateClient",
        return_value=client,
    ):
        await block.run_model(
            "owner/flux-schnell", {"prompt": "cat"}, SecretStr("fake-key")
        )

    client.predictions.async_create.assert_awaited_once_with(
        model="owner/flux-schnell", input={"prompt": "cat"}
    )


@pytest.mark.asyncio
async def test_run_model_emits_provider_cost_from_predict_time():
    """Core contract: provider_cost = predict_time * $0.0014/s, cost_usd."""
    block = ReplicateModelBlock()
    # 5-second run → 5 * 0.0014 = $0.007 → 150 cr/$ * 0.007 ceil = 2 cr
    prediction = _make_fake_prediction(output="result-data", predict_time=5.0)

    client = MagicMock()
    client.predictions.async_create = AsyncMock(return_value=prediction)

    captured: list[NodeExecutionStats] = []
    with (
        patch(
            "backend.blocks.replicate.replicate_block.ReplicateClient",
            return_value=client,
        ),
        patch.object(block, "merge_stats", side_effect=captured.append),
    ):
        result = await block.run_model("owner/model", {}, SecretStr("fake-key"))

    assert len(captured) == 1
    stats = captured[0]
    assert stats.provider_cost == pytest.approx(5.0 * _REPLICATE_USD_PER_SEC)
    assert stats.provider_cost_type == "cost_usd"
    assert result == "result-data"
    # async_wait MUST be called before reading metrics — otherwise metrics
    # is None on in-flight predictions.
    prediction.async_wait.assert_awaited_once()


@pytest.mark.asyncio
async def test_run_model_skips_merge_stats_when_metrics_missing():
    """Protect against the nightmare scenario: if the SDK stops populating
    metrics (or we hit a prediction that completes without metrics),
    merge_stats must NOT fire. Otherwise we'd emit a zero provider_cost
    that the resolver treats as 0 credits — a silent wallet-free leak.
    The block's run() path relies on the flat 0 fallback via
    charge_reconciled_usage's pre-flight balance guard.
    """
    block = ReplicateModelBlock()
    prediction = _make_fake_prediction(output="x", predict_time=None)

    client = MagicMock()
    client.predictions.async_create = AsyncMock(return_value=prediction)

    captured: list[NodeExecutionStats] = []
    with (
        patch(
            "backend.blocks.replicate.replicate_block.ReplicateClient",
            return_value=client,
        ),
        patch.object(block, "merge_stats", side_effect=captured.append),
    ):
        await block.run_model("owner/model", {}, SecretStr("fake-key"))

    # No merge_stats call → no provider_cost emission → COST_USD resolver
    # returns 0 → run is effectively free post-flight, but pre-flight
    # balance guard still blocks zero-balance wallets per PR #12894.
    assert captured == []


@pytest.mark.asyncio
async def test_run_model_skips_merge_when_predict_time_is_zero():
    """A 0-second predict_time would emit provider_cost=0, which is useless
    telemetry. Treat 0 same as missing (no emission)."""
    block = ReplicateModelBlock()
    prediction = _make_fake_prediction(output="x", predict_time=0)

    client = MagicMock()
    client.predictions.async_create = AsyncMock(return_value=prediction)

    captured: list[NodeExecutionStats] = []
    with (
        patch(
            "backend.blocks.replicate.replicate_block.ReplicateClient",
            return_value=client,
        ),
        patch.object(block, "merge_stats", side_effect=captured.append),
    ):
        await block.run_model("owner/model", {}, SecretStr("fake-key"))

    assert captured == []


@pytest.mark.asyncio
async def test_run_model_raises_on_failed_status_and_does_not_bill():
    """async_wait returns normally on 'failed' — without an explicit status
    check we'd bill partial compute time AND yield 'status: succeeded' with
    empty output. Verify we raise BEFORE merge_stats so the failed run is
    not billed."""
    block = ReplicateModelBlock()
    prediction = _make_fake_prediction(
        output=None, predict_time=2.5, status="failed", error="CUDA OOM"
    )

    client = MagicMock()
    client.predictions.async_create = AsyncMock(return_value=prediction)

    captured: list[NodeExecutionStats] = []
    with (
        patch(
            "backend.blocks.replicate.replicate_block.ReplicateClient",
            return_value=client,
        ),
        patch.object(block, "merge_stats", side_effect=captured.append),
    ):
        with pytest.raises(RuntimeError, match="CUDA OOM"):
            await block.run_model("owner/model", {}, SecretStr("fake-key"))

    assert captured == []


@pytest.mark.asyncio
async def test_run_model_raises_on_canceled_status_and_does_not_bill():
    """Canceled predictions — same guarantees as failed: don't bill, surface
|
||||
the cancellation."""
|
||||
block = ReplicateModelBlock()
|
||||
prediction = _make_fake_prediction(output=None, predict_time=1.0, status="canceled")
|
||||
|
||||
client = MagicMock()
|
||||
client.predictions.async_create = AsyncMock(return_value=prediction)
|
||||
|
||||
captured: list[NodeExecutionStats] = []
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.replicate.replicate_block.ReplicateClient",
|
||||
return_value=client,
|
||||
),
|
||||
patch.object(block, "merge_stats", side_effect=captured.append),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="canceled"):
|
||||
await block.run_model("owner/model", {}, SecretStr("fake-key"))
|
||||
|
||||
assert captured == []
|
||||
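The failure-path tests above pin down a contract: raise on failed/canceled before any stats are merged, and skip emission entirely when `predict_time` is missing or zero. A minimal standalone sketch of that contract (function name and the per-second rate are hypothetical illustrations, not the block's actual implementation):

```python
# Hypothetical sketch of the settlement guard exercised by the tests above.
# The rate below is a placeholder, NOT the real _REPLICATE_USD_PER_SEC value.
REPLICATE_USD_PER_SEC = 0.000225  # placeholder rate for illustration


def settle_prediction(status: str, predict_time, error=None):
    """Return the provider_cost (USD) to record, or None for 'emit nothing'.

    Raises before any cost is computed when the prediction did not succeed,
    so a failed run is never billed.
    """
    if status == "failed":
        raise RuntimeError(error or "prediction failed")
    if status == "canceled":
        raise RuntimeError("prediction canceled")
    if not predict_time:  # None and 0 both mean: skip emission
        return None
    return predict_time * REPLICATE_USD_PER_SEC
```

Treating `None` and `0` identically mirrors the "zero predict_time is useless telemetry" test: neither value produces a `merge_stats` call.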
autogpt_platform/backend/backend/blocks/slant3d/_config.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""Provider registration for Slant 3D — metadata only (auth lives in ``_api.py``)."""

from backend.sdk import ProviderBuilder

slant3d = (
    ProviderBuilder("slant3d")
    .with_description("On-demand 3D printing")
    .with_supported_auth_types("api_key")
    .build()
)
autogpt_platform/backend/backend/blocks/smartlead/_config.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""Provider registration for Smartlead — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

smartlead = (
    ProviderBuilder("smartlead")
    .with_description("Cold email outreach at scale")
    .with_supported_auth_types("api_key")
    .build()
)
@@ -6,6 +6,7 @@ from backend.sdk import BlockCostType, ProviderBuilder
# to COST_USD.
stagehand = (
    ProviderBuilder("stagehand")
    .with_description("AI browser automation")
    .with_api_key("STAGEHAND_API_KEY", "Stagehand API Key")
    .with_base_cost(1, BlockCostType.SECOND, cost_divisor=3)
    .build()
autogpt_platform/backend/backend/blocks/telegram/_config.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""Provider registration for Telegram — metadata only."""

from backend.sdk import ProviderBuilder

telegram = (
    ProviderBuilder("telegram")
    .with_description("Bot messaging and groups")
    .with_supported_auth_types("api_key")
    .build()
)
@@ -105,10 +105,13 @@ class UnrealTextToSpeechBlock(Block):
             input_data.text,
             input_data.voice_id,
         )
+        # Unreal Speech: $16 / 1M chars = $0.000016/char. Emit USD so the
+        # COST_USD resolver (150 cr/$ via BLOCK_COSTS) bills proportionally
+        # instead of the old flat 5 cr.
         self.merge_stats(
             NodeExecutionStats(
-                provider_cost=float(len(input_data.text)),
-                provider_cost_type="characters",
+                provider_cost=len(input_data.text) * 0.000016,
+                provider_cost_type="cost_usd",
             )
         )
         yield "mp3_url", api_response["OutputUri"]
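A quick sanity check of the constant in the hunk above ($16 per 1M characters), plus what that means in credits at the 150 cr/$ rate the comment mentions:

```python
# $16 per 1,000,000 characters → $0.000016 per character.
USD_PER_CHAR = 16 / 1_000_000

text = "Hello, world!"  # 13 characters
cost_usd = len(text) * USD_PER_CHAR  # 13 * 0.000016 = 0.000208 USD

# At 150 credits per dollar that is roughly 0.03 credits, versus
# the old flat 5-credit charge the comment refers to.
cost_credits = cost_usd * 150
```

So for short texts the USD-proportional billing is orders of magnitude cheaper than the old flat rate; the flat rate only breaks even somewhere around 33K characters.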
autogpt_platform/backend/backend/blocks/todoist/_config.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""Provider registration for Todoist — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

todoist = (
    ProviderBuilder("todoist")
    .with_description("Tasks and projects")
    .with_supported_auth_types("oauth2")
    .build()
)
autogpt_platform/backend/backend/blocks/twitter/_config.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""Provider registration for X (Twitter) — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

twitter = (
    ProviderBuilder("twitter")
    .with_description("Tweets, timelines, and DMs")
    .with_supported_auth_types("oauth2")
    .build()
)
@@ -27,7 +27,7 @@ from backend.blocks.video._utils import (
     strip_chapters_inplace,
 )
 from backend.data.execution import ExecutionContext
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
 from backend.util.exceptions import BlockExecutionError
 from backend.util.file import MediaFileType, get_exec_file_path, store_media_file

@@ -44,7 +44,8 @@ class VideoNarrationBlock(Block):
     )
     script: str = SchemaField(description="Narration script text")
     voice_id: str = SchemaField(
-        description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM"  # Rachel
+        description="ElevenLabs voice ID",
+        default="21m00Tcm4TlvDq8ikWAM",  # Rachel
     )
     model_id: Literal[
         "eleven_multilingual_v2",

@@ -124,6 +125,26 @@ class VideoNarrationBlock(Block):
             return_format="for_block_output",
         )

+    # Models that consume 0.5 credits per character (v2.5 tier). All other
+    # models default to 1.0 credit per character.
+    _HALF_RATE_MODELS = {"eleven_flash_v2_5", "eleven_turbo_v2_5"}
+    # ElevenLabs Starter plan: $5 / 30K credits = $0.000167 / credit.
+    _USD_PER_CREDIT = 0.000167
+
+    def _record_script_cost(self, script: str, model_id: str) -> None:
+        """Emit provider_cost (USD) for the narration run so the COST_USD
+        resolver can bill real ElevenLabs spend. Flash/Turbo v2.5 bill at
+        half the char rate of Multilingual/Turbo v2.
+        """
+        credits_per_char = 0.5 if model_id in self._HALF_RATE_MODELS else 1.0
+        script_usd = len(script) * self._USD_PER_CREDIT * credits_per_char
+        self.merge_stats(
+            NodeExecutionStats(
+                provider_cost=script_usd,
+                provider_cost_type="cost_usd",
+            )
+        )
+
     def _generate_narration_audio(
         self, api_key: str, script: str, voice_id: str, model_id: str
     ) -> bytes:

@@ -223,6 +244,8 @@ class VideoNarrationBlock(Block):
             input_data.model_id,
         )

+        self._record_script_cost(input_data.script, input_data.model_id)
+
         # Save audio to exec file path
         audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
         audio_abspath = get_exec_file_path(
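A standalone sanity check of the half-rate selection above, with the constants copied from the hunk (the free function here is an illustration of the same arithmetic, not the block's method):

```python
# Constants copied from the VideoNarrationBlock hunk above.
HALF_RATE_MODELS = {"eleven_flash_v2_5", "eleven_turbo_v2_5"}
USD_PER_CREDIT = 0.000167  # $5 / 30K credits on the ElevenLabs Starter plan


def script_cost_usd(script: str, model_id: str) -> float:
    """Same formula as _record_script_cost: chars * USD/credit * rate."""
    credits_per_char = 0.5 if model_id in HALF_RATE_MODELS else 1.0
    return len(script) * USD_PER_CREDIT * credits_per_char


# A 1,000-character script:
full = script_cost_usd("x" * 1000, "eleven_multilingual_v2")  # ≈ $0.167
half = script_cost_usd("x" * 1000, "eleven_flash_v2_5")       # ≈ $0.0835
```

Flash/Turbo v2.5 come out at exactly half the Multilingual/Turbo v2 cost, which matches the 0.5-credit-per-character tier the comment describes.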
@@ -15,6 +15,7 @@ from ._api import llm_api_call

wolfram = (
    ProviderBuilder("wolfram")
    .with_description("Computational knowledge engine")
    .with_api_key("WOLFRAM_APP_ID", "Wolfram Alpha App ID")
    .with_base_cost(1, BlockCostType.RUN)
    .build()
@@ -4,6 +4,7 @@ from ._oauth import WordPressOAuthHandler, WordPressScope

wordpress = (
    ProviderBuilder("wordpress")
    .with_description("Posts, pages, and media")
    .with_base_cost(1, BlockCostType.RUN)
    .with_oauth(
        WordPressOAuthHandler,
@@ -0,0 +1,10 @@
"""Provider registration for ZeroBounce — metadata only (auth lives in ``_auth.py``)."""

from backend.sdk import ProviderBuilder

zerobounce = (
    ProviderBuilder("zerobounce")
    .with_description("Email address verification")
    .with_supported_auth_types("api_key")
    .build()
)
autogpt_platform/backend/backend/copilot/bot/AGENTS.md (new file, 1 line)
@@ -0,0 +1 @@
@README.md

autogpt_platform/backend/backend/copilot/bot/CLAUDE.md (new file, 1 line)
@@ -0,0 +1 @@
@README.md
autogpt_platform/backend/backend/copilot/bot/README.md (new file, 79 lines)
@@ -0,0 +1,79 @@
# CoPilot Bot

Multi-platform chat bot that bridges AutoPilot to Discord (and later Telegram, Slack, etc.).

## Running

```bash
# As a standalone service
poetry run copilot-bot

# Or auto-start alongside the rest of the platform
poetry run app  # starts the bot too if AUTOPILOT_BOT_DISCORD_TOKEN is set
```

## Required environment variables

See `backend/.env.default` for the full list with documentation. Minimum setup:

| Variable | Purpose |
|----------|---------|
| `AUTOPILOT_BOT_DISCORD_TOKEN` | Discord bot token — enables the Discord adapter |
| `FRONTEND_BASE_URL` | Frontend base URL for link confirmation pages (shared with the rest of the backend) |
| `REDIS_HOST` / `REDIS_PORT` | Session + thread subscription state + copilot stream subscription (inherited from the shared backend config) |
| `PLATFORMLINKINGMANAGER_HOST` | DNS name of the `PlatformLinkingManager` service pod (cluster-internal RPC) |

## Architecture

```
bot/
├── app.py            # CoPilotChatBridge(AppService), adapter factory, outbound @expose RPC
├── config.py         # Shared (platform-agnostic) config
├── handler.py        # Core logic: routing, linking, batched streaming
├── bot_backend.py    # Thin facade over PlatformLinkingManagerClient + stream_registry
├── text.py           # Text splitting + batch formatting
├── threads.py        # Redis-backed thread subscription tracking
└── adapters/
    ├── base.py       # PlatformAdapter interface + MessageContext
    └── discord/
        ├── adapter.py   # Gateway connection, events, sends, thread creation
        ├── commands.py  # Slash commands (/setup, /help, /unlink)
        └── config.py    # Discord token + platform limits
```

**Locality rule:** anything platform-specific lives under `adapters/<platform>/`.
The only file that names specific platforms is `app.py`, which is the factory
that decides which adapters to instantiate based on which tokens are set.

## How messaging works

1. User mentions the bot in a channel
2. Adapter's `on_message` handler fires, constructs a `MessageContext`, and passes
   it to the shared `MessageHandler`
3. Handler:
   - Checks if the user/server is linked (via `bot_backend`)
   - If not linked → sends a "Link Account" button prompt
   - If linked → creates a thread (for channels) or uses the existing thread/DM
   - Marks the thread as subscribed in Redis (7-day TTL)
   - Streams the AutoPilot response back, chunked at the adapter's
     `chunk_flush_at` boundary
4. Messages that arrive while a stream is running get batched and sent as a
   single follow-up turn once the current stream ends
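Step 4 above (batching messages that arrive mid-stream into one follow-up turn) can be sketched roughly as follows; the class name, fields, and `run_turn` callback are illustrative, not the actual `handler.py` API:

```python
import asyncio


class ThreadSession:
    """Illustrative sketch: queue messages that arrive while a stream is
    running, then replay them as a single follow-up turn afterwards."""

    def __init__(self) -> None:
        self.streaming = False
        self.pending: list[str] = []

    async def on_message(self, text: str, run_turn) -> None:
        if self.streaming:
            # A stream is in flight: batch this message for later.
            self.pending.append(text)
            return
        self.streaming = True
        try:
            await run_turn(text)
        finally:
            self.streaming = False
        if self.pending:
            # Everything that arrived mid-stream becomes one follow-up turn.
            batch, self.pending = "\n".join(self.pending), []
            await self.on_message(batch, run_turn)
```

The key property is that a mid-stream arrival never starts a second concurrent stream; it only extends the conversation by one batched turn once the current one finishes.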
## Adding a new platform

1. Create `adapters/<platform>/` with `adapter.py`, `commands.py` (if the
   platform has commands), and `config.py`
2. `adapter.py` subclasses `PlatformAdapter` and implements all its abstract
   methods — `max_message_length`, `chunk_flush_at`, `send_message`,
   `send_link`, `create_thread`, etc.
3. `config.py` declares the platform's env vars and any platform-specific
   numbers (message limits, token name, etc.)
4. Add two lines to `app.py::_build_adapters`:
   ```python
   if <platform>_config.BOT_TOKEN:
       adapters.append(<Platform>Adapter(api))
   ```

The core handler, text utilities, thread tracking, and platform API all stay
untouched.
autogpt_platform/backend/backend/copilot/bot/__main__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Entry point for running the CoPilot Chat Bridge service.

Usage:
    poetry run copilot-bot
    python -m backend.copilot.bot
"""

from backend.app import run_processes

from .app import CoPilotChatBridge


def main():
    """Run the CoPilot Chat Bridge service."""
    run_processes(CoPilotChatBridge())


if __name__ == "__main__":
    main()
Some files were not shown because too many files have changed in this diff.