Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-04-30 03:00:41 -04:00

Compare commits (24): feat/platf ... fix/copilo

SHA1s: 7a9b0827bc, 7cc1edc61f, 9969b1ac07, d765715fbc, 8e13d4cb27, e59fe5af76, 9e8622c1d1, 0dcd25f73f, 78619ba090, 4a1741cc15, c08b9774dc, fe3d6fb118, c6d31f8252, 28ae7ebac8, e0f9146d54, c3c2737c42, 37f247c795, ae4a421620, 2879528308, 1974ec6260, 932ecd3a07, 4a567a55a4, 2b28434786, 5d1cdc2bad
@@ -160,6 +160,24 @@ while clean_polls < required_clean:

Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.

### Concrete CI fetch (don't parse `gh pr checks` text columns)

The `fetch_check_runs(PR)` step above must use `--json`, not the default text output. Job names can contain spaces and parentheses (e.g. `test (3.11)`, `Analyze (python)`), so `gh pr checks $PR | awk '{print $2}'` extracts `(3.11)` instead of the status — leading to a clean-poll firing while jobs are still pending.

```bash
# Reliable: use --json so columns are unambiguous.
ci_json=$(gh pr checks "$PR" --repo Significant-Gravitas/AutoGPT --json name,state,bucket)
pending=$(echo "$ci_json" | jq '[.[] | select(.bucket == "pending")] | length')
failed=$(echo "$ci_json" | jq '[.[] | select(.bucket == "fail" or .bucket == "cancel")] | length')

# Buckets are: pass | fail | pending | cancel | skipping
# (NOTE: gh pr checks does NOT expose `conclusion` as a JSON field —
# only `bucket`. Don't confuse with the GitHub REST API's check_runs
# endpoint, which DOES use conclusion.)
```

Map back to the pseudocode above: `bucket == "pending"` is `ci.conclusion is None (still in_progress)`; `bucket in {"fail", "cancel"}` is `ci.conclusion in NON_SUCCESS_TERMINAL`; `bucket in {"pass", "skipping"}` is clean.

### Why 2 clean polls, not 1

A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving give high confidence nothing else is incoming.
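The streak logic above can be sketched as a tiny state machine (a pure-Python sketch; the snapshot fields `pending`, `failed`, and `new_activity` are hypothetical stand-ins for the real poll results):

```python
def polls_until_done(snapshots, required_clean=2):
    """Return the 1-based poll at which the DONE condition fires, or None.

    A snapshot is "clean" when CI has nothing pending or failed and no new
    review/comment activity arrived; any dirty snapshot resets the streak.
    """
    clean_polls = 0
    for n, snap in enumerate(snapshots, start=1):
        if snap["pending"] == 0 and snap["failed"] == 0 and not snap["new_activity"]:
            clean_polls += 1
        else:
            clean_polls = 0  # one noisy cycle invalidates prior quiet evidence
        if clean_polls >= required_clean:
            return n
    return None
```

A clean poll followed by a late bot review drops the streak back to zero, which is exactly why a single green snapshot is not enough.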
@@ -196,6 +214,18 @@ The child skill returning is a **loop iteration boundary**, not a conversation t

If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.

### **Run /pr-polish in the foreground — never in a background agent**

Spawning `/pr-polish` inside an `Agent(subagent_type="general-purpose")` background task **does not work**. Background agents don't inherit the parent's slash-command registry, so `Skill(skill="pr-review")` and `Skill(skill="pr-address")` calls aren't available — the agent has to manually replicate the child skills' logic, which is fragile and tends to stall on the first network or rate-limit hiccup. Symptom: the background task reports `stalled: no progress for 600s` mid-review.

Run `/pr-polish` inline in the foreground conversation. If the user asks for "/pr-polish + /pr-test in parallel", split them: run `/pr-polish` in the foreground first, and only then send the test step to a background agent (`/pr-test` doesn't itself need to invoke skills).

### **You MUST invoke `Skill(pr-review)` every round — even when bot reviews already exist**

A common failure mode: CodeRabbit / autogpt-reviewer / Sentry have already posted findings on the PR, and the orchestrator skips the `Skill(pr-review)` step on the assumption that "review has been done." That's wrong — the outer loop's purpose is to layer **the agent's own review** on top of the bot reviews, catching issues the bots miss (architecture, naming, cross-file invariants, hidden coupling). If the orchestrator only addresses bot findings without ever running its own review, the loop converges to "bot-clean" but not "agent-reviewed-clean," and the user reasonably asks "did /pr-polish even read the diff?"

**Self-check before reporting `ORCHESTRATOR:DONE`:** confirm at least one `Skill(skill="pr-review")` call appears in the current orchestration. If none does, the loop is incomplete — go back and run one round.

## GitHub rate limits

This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:
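One way to back off is to spread the remaining budget evenly over the time left in the rate-limit window (a sketch of the idea only; the ~200 floor and the pacing policy here are assumptions, not the skill's specified behavior):

```python
def graphql_pause_seconds(remaining: int, reset_epoch: float, now: float,
                          floor: int = 200) -> float:
    """Seconds to sleep between GraphQL calls once the budget runs low.

    At or above `floor` remaining calls, don't pause at all; below it, pace
    the remaining calls evenly across the time left until the window resets.
    """
    if remaining >= floor:
        return 0.0
    seconds_left = max(reset_epoch - now, 1.0)
    return seconds_left / max(remaining, 1)
```

For example, 100 calls left and 10 minutes until reset yields a 6-second pause per call; at zero remaining it sleeps out the whole window.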
.github/workflows/platform-backend-ci.yml (vendored, 79 changed lines)
@@ -119,10 +119,12 @@ jobs:
    runs-on: ubuntu-latest

    services:
      redis:
        image: redis:latest
        ports:
          - 6379:6379
      # Redis is provisioned as a real 3-shard cluster below via docker
      # run (see the "Start Redis Cluster" step). GHA services can't
      # override the image CMD or stand up multi-container clusters, so
      # that setup is inlined — it mirrors the topology of the local dev
      # compose stack (autogpt_platform/docker-compose.platform.yml) and
      # prod helm chart.
      rabbitmq:
        image: rabbitmq:4.1.4
        ports:
@@ -166,6 +168,68 @@ jobs:
        with:
          python-version: ${{ matrix.python-version }}

      - name: Start Redis Cluster (3 shards)
        run: |
          # 3-master Redis Cluster matching the local compose stack
          # (autogpt_platform/docker-compose.platform.yml) and prod. Each
          # shard runs in its own container on a dedicated bridge network,
          # announces its compose-style hostname for intra-network clients,
          # and publishes 1700N on the GHA host so tests can reach every
          # shard via localhost. The backend's ``_address_remap`` rewrites
          # every CLUSTER SLOTS reply to localhost:<announced-port>, which
          # picks the right published port per shard.
          #
          # Not reusing docker-compose.platform.yml directly because compose
          # validates the full file even when only some services are ``up``,
          # and that file references services (db/kong/...) defined in a
          # sibling compose file — pulling both in would needlessly couple
          # CI to the full local-dev stack.
          docker network create redis-cluster-ci
          for i in 0 1 2; do
            port=$((17000 + i))
            bus=$((27000 + i))
            docker run -d --name redis-$i --network redis-cluster-ci \
              --network-alias redis-$i \
              -p $port:$port \
              redis:7 \
              redis-server --port $port \
                --cluster-enabled yes \
                --cluster-config-file nodes.conf \
                --cluster-node-timeout 5000 \
                --cluster-require-full-coverage no \
                --cluster-announce-hostname redis-$i \
                --cluster-announce-port $port \
                --cluster-announce-bus-port $bus \
                --cluster-preferred-endpoint-type hostname
          done
          # Wait for each shard to accept commands.
          for i in 0 1 2; do
            port=$((17000 + i))
            for _ in $(seq 1 30); do
              docker exec redis-$i redis-cli -p $port ping 2>/dev/null | grep -q PONG && break
              sleep 1
            done
          done
          # Form the cluster from an init container on the same network so
          # --cluster-preferred-endpoint-type hostname resolves redis-0/1/2.
          docker run --rm --network redis-cluster-ci redis:7 \
            redis-cli --cluster create \
              redis-0:17000 redis-1:17001 redis-2:17002 \
              --cluster-replicas 0 --cluster-yes
          # Confirm convergence.
          for _ in $(seq 1 30); do
            state=$(docker exec redis-0 redis-cli -p 17000 cluster info | awk -F: '/^cluster_state:/ {print $2}' | tr -d '[:cntrl:]')
            if [ "$state" = "ok" ]; then
              echo "Redis Cluster ready (3 shards, state=ok)"
              docker exec redis-0 redis-cli -p 17000 cluster nodes
              exit 0
            fi
            sleep 1
          done
          echo "Redis Cluster failed to reach ok state" >&2
          docker exec redis-0 redis-cli -p 17000 cluster info >&2 || true
          exit 1
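On the client side, the `_address_remap` mentioned in the step's comments amounts to rewriting every shard address reported by CLUSTER SLOTS to localhost plus the announced port. A pure-Python sketch of that remap (the backend helper's real name and wiring may differ; the `RedisCluster(address_remap=...)` hookup shown in the trailing comment assumes redis-py >= 4.4 and is not exercised here):

```python
def remap_to_localhost(address):
    """Rewrite ("redis-1", 17001) -> ("localhost", 17001).

    Works because each shard announces a unique port (17000-17002) that the
    workflow publishes on the runner's localhost, so the port alone
    identifies the shard.
    """
    host, port = address
    return ("localhost", port)

# Hypothetical wiring via redis-py's address_remap hook (not verified here):
# from redis.cluster import RedisCluster
# rc = RedisCluster(host="localhost", port=17000, address_remap=remap_to_localhost)
```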
      - name: Setup Supabase
        uses: supabase/setup-cli@v1
        with:
@@ -286,8 +350,13 @@ jobs:
          SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
          JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
          REDIS_HOST: "localhost"
          REDIS_PORT: "6379"
          REDIS_PORT: "17000"
          ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
          # Opt-in: lets backend/data/e2e_redis_restart_test.py spin up its
          # own isolated 3-shard cluster (ports 27110–27112) and exercise
          # ``docker restart <shard>`` mid-stream. Off locally so a
          # contributor's ``poetry run test`` doesn't pay the ~15s cost.
          E2E_RESTART_ISOLATED: "1"

      - name: Upload coverage reports to Codecov
        if: ${{ !cancelled() }}
.gitignore (vendored, 4 changed lines)

@@ -196,3 +196,7 @@ test.db
plans/
.claude/worktrees/
test-results/

# Playwright MCP / local browser-testing artifacts
.playwright-mcp/
copilot-session-switch-qa/
@@ -1,33 +0,0 @@
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class RateLimitSettings(BaseSettings):
    redis_host: str = Field(
        default="redis://localhost:6379",
        description="Redis host",
        validation_alias="REDIS_HOST",
    )

    redis_port: str = Field(
        default="6379", description="Redis port", validation_alias="REDIS_PORT"
    )

    redis_password: Optional[str] = Field(
        default=None,
        description="Redis password",
        validation_alias="REDIS_PASSWORD",
    )

    requests_per_minute: int = Field(
        default=60,
        description="Maximum number of requests allowed per minute per API key",
        validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE",
    )

    model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")


RATE_LIMIT_SETTINGS = RateLimitSettings()
@@ -1,51 +0,0 @@
import time
from typing import Tuple

from redis import Redis

from .config import RATE_LIMIT_SETTINGS


class RateLimiter:
    def __init__(
        self,
        redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
        redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
        redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
        requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
    ):
        self.redis = Redis(
            host=redis_host,
            port=int(redis_port),
            password=redis_password,
            decode_responses=True,
        )
        self.window = 60
        self.max_requests = requests_per_minute

    async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]:
        """
        Check if request is within rate limits.

        Args:
            api_key_id: The API key identifier to check

        Returns:
            Tuple of (is_allowed, remaining_requests, reset_time)
        """
        now = time.time()
        window_start = now - self.window
        key = f"ratelimit:{api_key_id}:1min"

        pipe = self.redis.pipeline()
        pipe.zremrangebyscore(key, 0, window_start)
        pipe.zadd(key, {str(now): now})
        pipe.zcount(key, window_start, now)
        pipe.expire(key, self.window)

        _, _, request_count, _ = pipe.execute()

        remaining = max(0, self.max_requests - request_count)
        reset_time = int(now + self.window)

        return request_count <= self.max_requests, remaining, reset_time
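The deleted limiter's ZSET-based sliding window can be sketched in pure Python, with a plain list standing in for the Redis sorted set (illustration only; the real class above does the same four operations in one pipelined round trip):

```python
import time


class SlidingWindowCounter:
    """In-memory analogue of the Redis sorted-set sliding window above."""

    def __init__(self, max_requests=60, window=60.0):
        self.max_requests = max_requests
        self.window = window
        self.hits = []  # timestamps; stands in for the ZSET members

    def check(self, now=None):
        """Return (is_allowed, remaining, reset_time), like check_rate_limit."""
        now = time.time() if now is None else now
        window_start = now - self.window
        self.hits = [t for t in self.hits if t > window_start]  # ZREMRANGEBYSCORE
        self.hits.append(now)                                   # ZADD
        count = len(self.hits)                                  # ZCOUNT
        remaining = max(0, self.max_requests - count)
        return count <= self.max_requests, remaining, int(now + self.window)
```

As in the original, the just-added request is included in the count, so the limiter allows exactly `max_requests` hits per rolling window.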
@@ -1,32 +0,0 @@
from fastapi import HTTPException, Request
from starlette.middleware.base import RequestResponseEndpoint

from .limiter import RateLimiter


async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint):
    """FastAPI middleware for rate limiting API requests."""
    limiter = RateLimiter()

    if not request.url.path.startswith("/api"):
        return await call_next(request)

    api_key = request.headers.get("Authorization")
    if not api_key:
        return await call_next(request)

    api_key = api_key.replace("Bearer ", "")

    is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key)

    if not is_allowed:
        raise HTTPException(
            status_code=429, detail="Rate limit exceeded. Please try again later."
        )

    response = await call_next(request)
    response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
    response.headers["X-RateLimit-Remaining"] = str(remaining)
    response.headers["X-RateLimit-Reset"] = str(reset_time)

    return response
@@ -1,13 +1,16 @@
import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union

from expiringdict import ExpiringDict

if TYPE_CHECKING:
    from redis.asyncio import Redis as AsyncRedis
    from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
    from redis.asyncio.lock import Lock as AsyncRedisLock

    AsyncRedisLike = Union[AsyncRedis, AsyncRedisCluster]


class AsyncRedisKeyedMutex:
    """
@@ -17,7 +20,7 @@ class AsyncRedisKeyedMutex:
    in case the key is not unlocked for a specified duration, to prevent memory leaks.
    """

    def __init__(self, redis: "AsyncRedis", timeout: int | None = 60):
    def __init__(self, redis: "AsyncRedisLike", timeout: int | None = 60):
        self.redis = redis
        self.timeout = timeout
        self.locks: dict[Any, "AsyncRedisLock"] = ExpiringDict(
@@ -37,6 +37,23 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=

# Web Push (VAPID) — generate with: poetry run python -c "
# from py_vapid import Vapid; import base64
# from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
# v = Vapid(); v.generate_keys()
# raw_priv = v.private_key.private_numbers().private_value.to_bytes(32, 'big')
# print('VAPID_PRIVATE_KEY=' + base64.urlsafe_b64encode(raw_priv).rstrip(b'=').decode())
# raw_pub = v.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
# print('VAPID_PUBLIC_KEY=' + base64.urlsafe_b64encode(raw_pub).rstrip(b'=').decode())
# "
# Dev-only keypair below — DO NOT use in staging/production. Regenerate
# your own with the snippet above before any non-local deployment.
VAPID_PRIVATE_KEY=17hBPdSdn6TR_yAgQxA0TjTcvRj3Lf6znHnASZ4rOKc
VAPID_PUBLIC_KEY=BBg49iVTWthVbRYphwmZNvZyiSJDqtSO4nmLxDzLKe3Oo9jbtu0Usa14xX4HQQNLUeiEfzD42zWSlrvY1PR12bs
# Per RFC 8292, push services use this in 410 Gone reports; set it to a real
# mailbox in production. Defaults to a placeholder for local dev.
VAPID_CLAIM_EMAIL=mailto:dev@example.com

## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000
@@ -182,6 +199,10 @@ GOOGLE_MAPS_API_KEY=
# Platform Bot Linking
PLATFORM_LINK_BASE_URL=http://localhost:3000/link

# CoPilot chat-platform bridge (Discord/Telegram/Slack)
# Uses FRONTEND_BASE_URL (above) for link confirmation pages.
AUTOPILOT_BOT_DISCORD_TOKEN=

# Communication Services
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=
@@ -1,14 +1,44 @@
import asyncio
from typing import Dict, Set
import json
import logging
import time
from typing import Awaitable, Callable, Dict, Optional, Set

from fastapi import WebSocket
from fastapi import WebSocket, WebSocketDisconnect
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.client import PubSub as AsyncPubSub
from redis.exceptions import MovedError, RedisError, ResponseError
from starlette.websockets import WebSocketState

from backend.api.model import NotificationPayload, WSMessage, WSMethod
from backend.api.model import WSMessage, WSMethod
from backend.data import redis_client as redis
from backend.data.event_bus import _assert_no_wildcard
from backend.data.execution import (
    ExecutionEventType,
    GraphExecutionEvent,
    NodeExecutionEvent,
    exec_channel,
    get_graph_execution_meta,
    graph_all_channel,
)
from backend.data.notification_bus import NotificationEvent
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
_settings = Settings()


def _is_ws_close_race(exc: BaseException, websocket: WebSocket) -> bool:
    """A SPUBLISH→WS send racing with WS close — benign, drop quietly."""
    if isinstance(exc, WebSocketDisconnect):
        return True
    if (
        getattr(websocket, "application_state", None) == WebSocketState.DISCONNECTED
        or getattr(websocket, "client_state", None) == WebSocketState.DISCONNECTED
    ):
        return True
    if isinstance(exc, RuntimeError) and "close message has been sent" in str(exc):
        return True
    return False


_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
    ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
@@ -16,128 +46,379 @@ _EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
}


def event_bus_channel(channel_key: str) -> str:
    """Prefix a channel key with the execution event bus name."""
    return f"{_settings.config.execution_event_bus_name}/{channel_key}"


def _notification_bus_channel(user_id: str) -> str:
    """Return the full sharded channel name for a user's notifications."""
    return f"{_settings.config.notification_event_bus_name}/{user_id}"
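Sharded (SSUBSCRIBE) channels like these land on whichever node owns the channel's hash slot. A pure-Python sketch of the slot computation from the Redis Cluster spec (CRC16-XMODEM mod 16384, hashing only the `{...}` hash tag when one is present), shown here to make the "pinned to the owning shard" behavior concrete:

```python
def crc16_xmodem(data: bytes) -> int:
    """CRC16-CCITT (XModem), the checksum Redis Cluster uses for key slots."""
    crc = 0
    for byte in data:
        crc ^= byte << 8
        for _ in range(8):
            crc = ((crc << 1) ^ 0x1021) if crc & 0x8000 else (crc << 1)
            crc &= 0xFFFF
    return crc


def hash_slot(channel: str) -> int:
    """Slot 0-16383 for a key/channel, hashing only the {tag} if non-empty."""
    start = channel.find("{")
    if start != -1:
        end = channel.find("}", start + 1)
        if end > start + 1:  # empty tags {} are ignored, per the spec
            channel = channel[start + 1 : end]
    return crc16_xmodem(channel.encode()) % 16384
```

Two channel names sharing the same `{tag}` hash to the same slot, so they stay on one shard; this is why the hash-tagged execution channels below need the graph id resolved before subscribing.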
MessageHandler = Callable[[Optional[bytes | str]], Awaitable[None]]


def _is_moved_error(exc: BaseException) -> bool:
    """A MOVED redirect — slot migration mid-stream; pump should reconnect."""
    if isinstance(exc, MovedError):
        return True
    if isinstance(exc, ResponseError) and str(exc).startswith("MOVED "):
        return True
    return False


# Reconnect tunables for shard-failover during pubsub.listen().
_PUMP_RECONNECT_DEADLINE_S = 60.0
_PUMP_RECONNECT_BACKOFF_INITIAL_S = 0.5
_PUMP_RECONNECT_BACKOFF_MAX_S = 8.0
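Under these tunables, a pump whose reopens all fail (so the deadline is never reset by a successful read) sleeps on a capped doubling schedule before giving up. A small sketch, purely to make the timing concrete:

```python
def reconnect_waits(initial=0.5, maximum=8.0, deadline=60.0):
    """Sleep durations a permanently-failing pump would use before giving up.

    Doubles from `initial`, caps at `maximum`, and stops once the cumulative
    wait has passed `deadline` (a simplification: the real pump also resets
    the deadline whenever a message is successfully read).
    """
    waits, elapsed, wait = [], 0.0, initial
    while elapsed <= deadline:
        waits.append(wait)
        elapsed += wait
        wait = min(wait * 2, maximum)
    return waits
```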


class _Subscription:
    """One SSUBSCRIBE lifecycle bound to a WebSocket, pinned to the owning shard."""

    def __init__(self, full_channel: str) -> None:
        _assert_no_wildcard(full_channel)
        self.full_channel = full_channel
        self._client: AsyncRedis | None = None
        self._pubsub: AsyncPubSub | None = None
        self._task: asyncio.Task | None = None

    async def start(self, on_message: MessageHandler) -> None:
        await self._open_pubsub()
        self._task = asyncio.create_task(self._pump(on_message))

    async def _open_pubsub(self) -> None:
        """(Re)establish the sharded pubsub connection + SSUBSCRIBE."""
        self._client = await redis.connect_sharded_pubsub_async(self.full_channel)
        self._pubsub = self._client.pubsub()
        await self._pubsub.execute_command("SSUBSCRIBE", self.full_channel)
        # redis-py 6.x async PubSub.listen() exits when ``channels`` is
        # empty; raw SSUBSCRIBE doesn't populate it, so do it ourselves.
        self._pubsub.channels[self.full_channel] = None  # type: ignore[index]

    async def _close_pubsub_quietly(self) -> None:
        """Best-effort teardown before reconnect — never raises."""
        if self._pubsub is not None:
            try:
                await self._pubsub.aclose()
            except Exception:
                pass
            self._pubsub = None
        if self._client is not None:
            try:
                await self._client.aclose()
            except Exception:
                pass
            self._client = None

    async def _pump(self, on_message: MessageHandler) -> None:
        if self._pubsub is None:
            return
        backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
        deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
        while True:
            pubsub = self._pubsub
            if pubsub is None:
                return
            needs_reconnect = False
            try:
                async for message in pubsub.listen():
                    msg_type = message.get("type")
                    # Server-pushed sunsubscribe: slot ownership changed and
                    # Redis revoked our SSUBSCRIBE without dropping the TCP.
                    # Treat as a reconnect trigger so we re-resolve the shard.
                    if msg_type == "sunsubscribe":
                        needs_reconnect = True
                        break
                    if msg_type not in ("smessage", "message", "pmessage"):
                        continue
                    # Successful read resets the reconnect budget.
                    backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
                    deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
                    try:
                        await on_message(message.get("data"))
                    except Exception:
                        logger.exception(
                            "Websocket message-handler failed for channel %s",
                            self.full_channel,
                        )
                if not needs_reconnect:
                    # listen() exited cleanly (channels emptied) — pump is done.
                    return
            except asyncio.CancelledError:
                raise
            except (ConnectionError, RedisError) as exc:
                if isinstance(exc, ResponseError) and not _is_moved_error(exc):
                    logger.exception(
                        "Pubsub pump crashed on non-retryable ResponseError for %s",
                        self.full_channel,
                    )
                    return
                if time.monotonic() > deadline:
                    logger.exception(
                        "Pubsub pump giving up after reconnect deadline for %s",
                        self.full_channel,
                    )
                    return
                logger.warning(
                    "Pubsub pump reconnecting for %s after %s: %s",
                    self.full_channel,
                    type(exc).__name__,
                    exc,
                )
            except Exception:
                logger.exception("Pubsub pump crashed for %s", self.full_channel)
                return

            # Either a retryable error was raised, or the server pushed a
            # sunsubscribe — close the stale pubsub and reopen against the
            # (possibly migrated) shard.
            await self._close_pubsub_quietly()
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, _PUMP_RECONNECT_BACKOFF_MAX_S)
            try:
                await self._open_pubsub()
            except (ConnectionError, RedisError) as reopen_exc:
                logger.warning(
                    "Pubsub pump reopen failed for %s: %s",
                    self.full_channel,
                    reopen_exc,
                )
                # Loop again — deadline check will eventually exit.
                continue

    async def stop(self) -> None:
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except (asyncio.CancelledError, Exception):
                pass
            self._task = None
        if self._pubsub is not None:
            try:
                await self._pubsub.execute_command("SUNSUBSCRIBE", self.full_channel)
            except Exception:
                logger.warning(
                    "SUNSUBSCRIBE failed for %s", self.full_channel, exc_info=True
                )
            try:
                await self._pubsub.aclose()
            except Exception:
                pass
            self._pubsub = None
        if self._client is not None:
            try:
                await self._client.aclose()
            except Exception:
                pass
            self._client = None

class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Set[WebSocket] = set()
|
||||
# channel_key → sockets subscribed (public channel keys, not raw Redis channels)
|
||||
self.subscriptions: Dict[str, Set[WebSocket]] = {}
|
||||
self.user_connections: Dict[str, Set[WebSocket]] = {}
|
||||
# websocket → {channel_key: _Subscription}
|
||||
self._ws_subs: Dict[WebSocket, Dict[str, _Subscription]] = {}
|
||||
# websocket → notification subscription
|
||||
self._ws_notifications: Dict[WebSocket, _Subscription] = {}
|
||||
|
||||
async def connect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections.add(websocket)
|
||||
if user_id not in self.user_connections:
|
||||
self.user_connections[user_id] = set()
|
||||
self.user_connections[user_id].add(websocket)
|
||||
self._ws_subs.setdefault(websocket, {})
|
||||
await self._start_notification_subscription(websocket, user_id=user_id)
|
||||
|
||||
def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
async def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
self.active_connections.discard(websocket)
|
||||
for subscribers in self.subscriptions.values():
|
||||
# Stop SSUBSCRIBE pumps before dropping bookkeeping to avoid leaks.
|
||||
subs = self._ws_subs.pop(websocket, {})
|
||||
for sub in subs.values():
|
||||
await sub.stop()
|
||||
notif_sub = self._ws_notifications.pop(websocket, None)
|
||||
if notif_sub is not None:
|
||||
await notif_sub.stop()
|
||||
for channel_key, subscribers in list(self.subscriptions.items()):
|
||||
subscribers.discard(websocket)
|
||||
user_conns = self.user_connections.get(user_id)
|
||||
if user_conns is not None:
|
||||
user_conns.discard(websocket)
|
||||
if not user_conns:
|
||||
self.user_connections.pop(user_id, None)
|
||||
if not subscribers:
|
||||
self.subscriptions.pop(channel_key, None)
|
||||
|
||||
async def subscribe_graph_exec(
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
) -> str:
|
||||
return await self._subscribe(
|
||||
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
|
||||
# Hash-tagged channel needs graph_id; resolve once per subscribe.
|
||||
meta = await get_graph_execution_meta(user_id, graph_exec_id)
|
||||
if meta is None:
|
||||
raise ValueError(
|
||||
f"graph_exec #{graph_exec_id} not found for user #{user_id}"
|
||||
)
|
||||
channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
|
||||
full_channel = event_bus_channel(
|
||||
exec_channel(user_id, meta.graph_id, graph_exec_id)
|
||||
)
|
||||
await self._open_subscription(websocket, channel_key, full_channel)
|
||||
return channel_key
|
||||
|
||||
async def subscribe_graph_execs(
|
||||
self, *, user_id: str, graph_id: str, websocket: WebSocket
|
||||
) -> str:
|
||||
return await self._subscribe(
|
||||
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
|
||||
)
|
||||
channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
|
||||
full_channel = event_bus_channel(graph_all_channel(user_id, graph_id))
|
||||
await self._open_subscription(websocket, channel_key, full_channel)
|
||||
return channel_key
|
||||
|
||||
async def unsubscribe_graph_exec(
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
) -> str | None:
|
||||
return await self._unsubscribe(
|
||||
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
|
||||
)
|
||||
channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
|
||||
return await self._close_subscription(websocket, channel_key)
|
||||
|
||||
async def unsubscribe_graph_execs(
|
||||
self, *, user_id: str, graph_id: str, websocket: WebSocket
|
||||
) -> str | None:
|
||||
return await self._unsubscribe(
|
||||
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
|
||||
)
|
||||
channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
|
||||
return await self._close_subscription(websocket, channel_key)
|
||||
|
||||
async def send_execution_update(
|
||||
self, exec_event: GraphExecutionEvent | NodeExecutionEvent
|
||||
) -> int:
|
||||
graph_exec_id = (
|
||||
exec_event.id
|
||||
if isinstance(exec_event, GraphExecutionEvent)
|
||||
else exec_event.graph_exec_id
|
||||
)
|
||||
async def _open_subscription(
|
||||
self, websocket: WebSocket, channel_key: str, full_channel: str
|
||||
) -> None:
|
||||
self.subscriptions.setdefault(channel_key, set()).add(websocket)
|
||||
per_ws = self._ws_subs.setdefault(websocket, {})
|
||||
if channel_key in per_ws:
|
||||
return
|
||||
sub = _Subscription(full_channel)
|
||||
|
||||
n_sent = 0
|
||||
async def on_message(data: Optional[bytes | str]) -> None:
|
||||
await self._forward_exec_event(websocket, channel_key, data)
|
||||
|
||||
channels: set[str] = {
|
||||
# Send update to listeners for this graph execution
|
||||
_graph_exec_channel_key(exec_event.user_id, graph_exec_id=graph_exec_id)
|
||||
}
|
||||
if isinstance(exec_event, GraphExecutionEvent):
|
||||
# Send update to listeners for all executions of this graph
|
||||
channels.add(
|
||||
_graph_execs_channel_key(
|
||||
exec_event.user_id, graph_id=exec_event.graph_id
|
||||
)
|
||||
)
|
||||
await sub.start(on_message)
|
||||
per_ws[channel_key] = sub
|
||||
|
||||
for channel in channels.intersection(self.subscriptions.keys()):
|
||||
message = WSMessage(
|
||||
method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
|
||||
channel=channel,
|
||||
data=exec_event.model_dump(),
|
||||
).model_dump_json()
|
||||
for connection in self.subscriptions[channel]:
|
||||
await connection.send_text(message)
|
||||
n_sent += 1
|
||||
|
||||
return n_sent
|
||||
|
||||
async def send_notification(
|
||||
self, *, user_id: str, payload: NotificationPayload
|
||||
) -> int:
|
||||
"""Send a notification to all websocket connections belonging to a user."""
|
||||
message = WSMessage(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data=payload.model_dump(),
|
||||
).model_dump_json()
|
||||
|
||||
connections = tuple(self.user_connections.get(user_id, set()))
|
||||
if not connections:
|
||||
return 0
|
||||
|
||||
await asyncio.gather(
|
||||
*(connection.send_text(message) for connection in connections),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
|
||||
if channel_key not in self.subscriptions:
|
||||
self.subscriptions[channel_key] = set()
|
||||
self.subscriptions[channel_key].add(websocket)

    async def _close_subscription(
        self, websocket: WebSocket, channel_key: str
    ) -> str | None:
        subscribers = self.subscriptions.get(channel_key)
        if subscribers is None:
            return None
        subscribers.discard(websocket)
        if not subscribers:
            self.subscriptions.pop(channel_key, None)
        per_ws = self._ws_subs.get(websocket)
        if per_ws and channel_key in per_ws:
            sub = per_ws.pop(channel_key)
            await sub.stop()
        return channel_key

    async def _unsubscribe(self, channel_key: str, websocket: WebSocket) -> str | None:
        if channel_key in self.subscriptions:
            self.subscriptions[channel_key].discard(websocket)
            if not self.subscriptions[channel_key]:
                del self.subscriptions[channel_key]
            return channel_key
        return None

    async def _forward_exec_event(
        self,
        websocket: WebSocket,
        channel_key: str,
        raw_payload: Optional[bytes | str],
    ) -> None:
        if raw_payload is None:
            return
        # Unwrap the `_EventPayloadWrapper` envelope, then re-wrap as a WS message.
        try:
            wrapper = (
                raw_payload.decode()
                if isinstance(raw_payload, (bytes, bytearray))
                else raw_payload
            )
        except Exception:
            logger.warning(
                "Failed to decode pubsub payload on %s", channel_key, exc_info=True
            )
            return

        try:
            parsed = json.loads(wrapper)
            event_data = parsed.get("payload")
            if not isinstance(event_data, dict):
                return
            event_type = event_data.get("event_type")
            method = _EVENT_TYPE_TO_METHOD_MAP.get(ExecutionEventType(event_type))
            if method is None:
                return
            message = WSMessage(
                method=method,
                channel=channel_key,
                data=event_data,
            ).model_dump_json()
            await websocket.send_text(message)
        except Exception as e:
            if _is_ws_close_race(e, websocket):
                logger.debug("Dropped exec event on closed WS for %s", channel_key)
                return
            logger.exception("Failed to forward exec event on %s", channel_key)

    async def _start_notification_subscription(
        self, websocket: WebSocket, *, user_id: str
    ) -> None:
        full_channel = _notification_bus_channel(user_id)
        sub = _Subscription(full_channel)

        async def on_message(data: Optional[bytes | str]) -> None:
            await self._forward_notification(websocket, user_id, data)

        try:
            await sub.start(on_message)
        except Exception:
            logger.exception(
                "Failed to open notification SSUBSCRIBE for user=%s", user_id
            )
            return
        self._ws_notifications[websocket] = sub

    async def _forward_notification(
        self,
        websocket: WebSocket,
        user_id: str,
        raw_payload: Optional[bytes | str],
    ) -> None:
        if raw_payload is None:
            return
        try:
            wrapper_json = (
                raw_payload.decode()
                if isinstance(raw_payload, (bytes, bytearray))
                else raw_payload
            )
            parsed = json.loads(wrapper_json)
            inner = parsed.get("payload") if isinstance(parsed, dict) else None
            if not isinstance(inner, dict):
                return
            event = NotificationEvent.model_validate(inner)
        except Exception:
            logger.warning(
                "Failed to parse notification payload for user=%s",
                user_id,
                exc_info=True,
            )
            return
        # Defense in depth against cross-user payloads.
        if event.user_id != user_id:
            return
        message = WSMessage(
            method=WSMethod.NOTIFICATION,
            data=event.payload.model_dump(),
        ).model_dump_json()
        try:
            await websocket.send_text(message)
        except Exception as e:
            if _is_ws_close_race(e, websocket):
                logger.debug("Dropped notification on closed WS for user=%s", user_id)
                return
            logger.warning(
                "Failed to deliver notification to WS for user=%s",
                user_id,
                exc_info=True,
            )


def graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
    return f"{user_id}|graph_exec#{graph_exec_id}"
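These channel keys feed Redis Cluster sharded pub/sub (SSUBSCRIBE/SPUBLISH), where the shard that owns a channel is picked by the standard CRC16 key-slot rule. A minimal sketch of that routing, written from the documented Redis Cluster algorithm rather than this codebase (function names here are illustrative):

```python
def crc16_xmodem(data: bytes) -> int:
    # CRC16-CCITT (XMODEM): polynomial 0x1021, init 0, the CRC Redis uses.
    crc = 0
    for byte in data:
        crc ^= byte << 8
        for _ in range(8):
            crc = ((crc << 1) ^ 0x1021) if crc & 0x8000 else (crc << 1)
            crc &= 0xFFFF
    return crc


def key_slot(channel: str) -> int:
    # Redis Cluster hash-tag rule: if the key contains "{...}" with a
    # non-empty body, only that body is hashed, so keys sharing a tag
    # land in the same slot (and therefore on the same shard).
    start = channel.find("{")
    if start != -1:
        end = channel.find("}", start + 1)
        if end > start + 1:
            channel = channel[start + 1 : end]
    return crc16_xmodem(channel.encode()) % 16384
```

Without a hash tag, per-execution channels spread across shards, which is exactly the situation the integration tests that follow exercise on the 3-shard cluster.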
@@ -0,0 +1,386 @@
"""ConnectionManager integration over the live 3-shard Redis cluster:
SSUBSCRIBE → SPUBLISH → WebSocket forwarding with no Redis mocks. Skips
when the cluster is unreachable."""

import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock
from uuid import uuid4

import pytest
from fastapi import WebSocket

import backend.data.redis_client as redis_client
from backend.api.conn_manager import (
    ConnectionManager,
    _graph_execs_channel_key,
    event_bus_channel,
    graph_exec_channel_key,
)
from backend.api.model import WSMethod
from backend.data.execution import (
    ExecutionStatus,
    GraphExecutionEvent,
    GraphExecutionMeta,
    NodeExecutionEvent,
    exec_channel,
    graph_all_channel,
)


def _has_live_cluster() -> bool:
    try:
        c = redis_client.connect()
    except Exception:  # noqa: BLE001 — any connect failure → skip
        return False
    try:
        c.close()
    except Exception:
        pass
    return True


pytestmark = pytest.mark.skipif(
    not _has_live_cluster(),
    reason="local redis cluster not reachable; skip conn_manager integration",
)


def _meta(user_id: str, graph_id: str, graph_exec_id: str) -> GraphExecutionMeta:
    """Build a minimal GraphExecutionMeta for ``subscribe_graph_exec`` to use."""
    return GraphExecutionMeta(
        id=graph_exec_id,
        user_id=user_id,
        graph_id=graph_id,
        graph_version=1,
        inputs=None,
        credential_inputs=None,
        nodes_input_masks=None,
        preset_id=None,
        status=ExecutionStatus.RUNNING,
        started_at=datetime.now(tz=timezone.utc),
        ended_at=None,
        stats=GraphExecutionMeta.Stats(),
    )


def _node_event_payload(
    *, user_id: str, graph_id: str, graph_exec_id: str, marker: str
) -> bytes:
    """Wire-format a NodeExecutionEvent the way RedisExecutionEventBus would."""
    inner = NodeExecutionEvent(
        user_id=user_id,
        graph_id=graph_id,
        graph_version=1,
        graph_exec_id=graph_exec_id,
        node_exec_id=f"node-exec-{marker}",
        node_id="node-1",
        block_id="block-1",
        status=ExecutionStatus.COMPLETED,
        input_data={"in": marker},
        output_data={"out": [marker]},
        add_time=datetime.now(tz=timezone.utc),
        queue_time=None,
        start_time=datetime.now(tz=timezone.utc),
        end_time=datetime.now(tz=timezone.utc),
    ).model_dump(mode="json")
    return json.dumps({"payload": inner}).encode()


def _graph_event_payload(
    *, user_id: str, graph_id: str, graph_exec_id: str, marker: str
) -> bytes:
    inner = GraphExecutionEvent(
        id=graph_exec_id,
        user_id=user_id,
        graph_id=graph_id,
        graph_version=1,
        preset_id=None,
        status=ExecutionStatus.COMPLETED,
        started_at=datetime.now(tz=timezone.utc),
        ended_at=datetime.now(tz=timezone.utc),
        stats=GraphExecutionEvent.Stats(
            cost=0,
            duration=1.0,
            node_exec_time=0.5,
            node_exec_count=1,
        ),
        inputs={"x": marker},
        credential_inputs=None,
        nodes_input_masks=None,
        outputs={"y": [marker]},
    ).model_dump(mode="json")
    return json.dumps({"payload": inner}).encode()


async def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.05) -> bool:
    """Poll ``predicate()`` until truthy or timeout — used to wait for pubsub."""
    deadline = asyncio.get_running_loop().time() + timeout
    while asyncio.get_running_loop().time() < deadline:
        if predicate():
            return True
        await asyncio.sleep(interval)
    return False


@pytest.mark.asyncio
async def test_two_clients_get_independent_ssubscribes_on_right_shards(
    monkeypatch,
) -> None:
    """Two WS clients on different graph_exec_ids each receive ONLY their
    own publish, even when the channels land on different shards."""
    user_id = "user-conn-int-1"
    graph_a = f"graph-a-{uuid4().hex[:8]}"
    graph_b = f"graph-b-{uuid4().hex[:8]}"
    exec_a = f"exec-a-{uuid4().hex[:8]}"
    exec_b = f"exec-b-{uuid4().hex[:8]}"

    # Stub Prisma lookup so tests don't need a DB.
    async def _fake_meta(_uid, gex_id):
        return _meta(user_id, graph_a if gex_id == exec_a else graph_b, gex_id)

    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)

    cm = ConnectionManager()
    ws_a: AsyncMock = AsyncMock(spec=WebSocket)
    ws_b: AsyncMock = AsyncMock(spec=WebSocket)
    sent_a: list[str] = []
    sent_b: list[str] = []
    ws_a.send_text = AsyncMock(side_effect=lambda m: sent_a.append(m))
    ws_b.send_text = AsyncMock(side_effect=lambda m: sent_b.append(m))

    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()

    try:
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_a, websocket=ws_a
        )
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_b, websocket=ws_b
        )
        # Let SSUBSCRIBE settle on each shard.
        await asyncio.sleep(0.2)

        # Publish to each per-exec channel.
        chan_a = event_bus_channel(exec_channel(user_id, graph_a, exec_a))
        chan_b = event_bus_channel(exec_channel(user_id, graph_b, exec_b))
        cluster.spublish(
            chan_a,
            _node_event_payload(
                user_id=user_id,
                graph_id=graph_a,
                graph_exec_id=exec_a,
                marker="A",
            ).decode(),
        )
        cluster.spublish(
            chan_b,
            _node_event_payload(
                user_id=user_id,
                graph_id=graph_b,
                graph_exec_id=exec_b,
                marker="B",
            ).decode(),
        )

        delivered = await _wait_until(lambda: sent_a and sent_b, timeout=5.0)
        assert delivered, f"timeout: sent_a={sent_a!r} sent_b={sent_b!r}"

        msg_a = json.loads(sent_a[0])
        msg_b = json.loads(sent_b[0])
        assert msg_a["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_a)
        assert msg_b["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_b)
        assert msg_a["data"]["graph_exec_id"] == exec_a
        assert msg_b["data"]["graph_exec_id"] == exec_b
        # No cross-talk: each socket got exactly one message.
        assert len(sent_a) == 1 and len(sent_b) == 1
    finally:
        await cm.disconnect_socket(ws_a, user_id=user_id)
        await cm.disconnect_socket(ws_b, user_id=user_id)
        redis_client.disconnect()


@pytest.mark.asyncio
async def test_aggregate_channel_receives_per_exec_publishes(monkeypatch) -> None:
    """A subscriber on the ``graph_execs`` aggregate channel must receive the
    GraphExecutionEvent published to the ``/all`` channel — even though
    per-exec events go to a different channel."""
    user_id = "user-conn-int-2"
    graph_id = f"graph-{uuid4().hex[:8]}"
    exec_id = f"exec-{uuid4().hex[:8]}"

    async def _fake_meta(_uid, gex_id):
        return _meta(user_id, graph_id, gex_id)

    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)

    cm = ConnectionManager()
    ws_agg: AsyncMock = AsyncMock(spec=WebSocket)
    ws_per: AsyncMock = AsyncMock(spec=WebSocket)
    sent_agg: list[str] = []
    sent_per: list[str] = []
    ws_agg.send_text = AsyncMock(side_effect=lambda m: sent_agg.append(m))
    ws_per.send_text = AsyncMock(side_effect=lambda m: sent_per.append(m))

    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()

    try:
        await cm.subscribe_graph_execs(
            user_id=user_id, graph_id=graph_id, websocket=ws_agg
        )
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_id, websocket=ws_per
        )
        await asyncio.sleep(0.2)

        # The eventbus publishes the same event to both channels — replicate.
        chan_per = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
        chan_all = event_bus_channel(graph_all_channel(user_id, graph_id))
        payload = _graph_event_payload(
            user_id=user_id,
            graph_id=graph_id,
            graph_exec_id=exec_id,
            marker="agg",
        ).decode()
        cluster.spublish(chan_per, payload)
        cluster.spublish(chan_all, payload)

        delivered = await _wait_until(lambda: sent_agg and sent_per, timeout=5.0)
        assert delivered, f"sent_agg={sent_agg!r} sent_per={sent_per!r}"
        agg_msg = json.loads(sent_agg[0])
        per_msg = json.loads(sent_per[0])
        # Aggregate subscriber's channel key is the per-graph executions key.
        assert agg_msg["channel"] == _graph_execs_channel_key(
            user_id, graph_id=graph_id
        )
        assert per_msg["channel"] == graph_exec_channel_key(
            user_id, graph_exec_id=exec_id
        )
        assert agg_msg["method"] == WSMethod.GRAPH_EXECUTION_EVENT.value
    finally:
        await cm.disconnect_socket(ws_agg, user_id=user_id)
        await cm.disconnect_socket(ws_per, user_id=user_id)
        redis_client.disconnect()


@pytest.mark.asyncio
async def test_disconnect_unsubscribes_and_drops_future_publishes(monkeypatch) -> None:
    """After ``disconnect_socket`` runs, a subsequent SPUBLISH must NOT reach
    the dead websocket — exercises the SUNSUBSCRIBE + bookkeeping cleanup."""
    user_id = "user-conn-int-3"
    graph_id = f"graph-{uuid4().hex[:8]}"
    exec_id = f"exec-{uuid4().hex[:8]}"

    async def _fake_meta(_uid, gex_id):
        return _meta(user_id, graph_id, gex_id)

    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)

    cm = ConnectionManager()
    ws: AsyncMock = AsyncMock(spec=WebSocket)
    sent: list[str] = []
    ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))

    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()
    chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
    payload = _node_event_payload(
        user_id=user_id, graph_id=graph_id, graph_exec_id=exec_id, marker="live"
    ).decode()

    try:
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_id, websocket=ws
        )
        await asyncio.sleep(0.15)

        # First publish — must reach the socket.
        cluster.spublish(chan, payload)
        delivered = await _wait_until(lambda: bool(sent), timeout=5.0)
        assert delivered
        assert len(sent) == 1

        # Disconnect → SUNSUBSCRIBE + bookkeeping cleared.
        await cm.disconnect_socket(ws, user_id=user_id)
        # Pump cancellation may drain in-flight messages; wait for it.
        await asyncio.sleep(0.2)

        # Channel bookkeeping must be gone.
        assert (
            graph_exec_channel_key(user_id, graph_exec_id=exec_id)
            not in cm.subscriptions
        )
        assert ws not in cm._ws_subs

        # Second publish — must NOT reach the (already-disconnected) socket.
        cluster.spublish(
            chan,
            _node_event_payload(
                user_id=user_id,
                graph_id=graph_id,
                graph_exec_id=exec_id,
                marker="post-disconnect",
            ).decode(),
        )
        await asyncio.sleep(0.5)
        # Still only the one pre-disconnect message.
        assert len(sent) == 1
    finally:
        redis_client.disconnect()


@pytest.mark.asyncio
async def test_slow_consumer_receives_all_events_without_loss(monkeypatch) -> None:
    """Burst-publish many SPUBLISHes; assert every one reaches the subscriber
    in order — guards against drops/reorderings in the pubsub pump."""
    user_id = "user-conn-int-4"
    graph_id = f"graph-{uuid4().hex[:8]}"
    exec_id = f"exec-{uuid4().hex[:8]}"
    n_events = 100

    async def _fake_meta(_uid, gex_id):
        return _meta(user_id, graph_id, gex_id)

    monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)

    cm = ConnectionManager()
    ws: AsyncMock = AsyncMock(spec=WebSocket)
    sent: list[str] = []
    ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))

    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()
    chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))

    try:
        await cm.subscribe_graph_exec(
            user_id=user_id, graph_exec_id=exec_id, websocket=ws
        )
        await asyncio.sleep(0.2)

        # Burst-publish n_events without yielding to the pump.
        for i in range(n_events):
            cluster.spublish(
                chan,
                _node_event_payload(
                    user_id=user_id,
                    graph_id=graph_id,
                    graph_exec_id=exec_id,
                    marker=f"m{i}",
                ).decode(),
            )

        delivered = await _wait_until(
            lambda: len(sent) >= n_events, timeout=15.0, interval=0.1
        )
        assert delivered, f"only delivered {len(sent)}/{n_events}"

        # Validate ordering — Redis pub/sub is FIFO per channel.
        markers = [json.loads(m)["data"]["input_data"]["in"] for m in sent[:n_events]]
        assert markers == [f"m{i}" for i in range(n_events)]
    finally:
        await cm.disconnect_socket(ws, user_id=user_id)
        redis_client.disconnect()

File diff suppressed because it is too large
@@ -8,7 +8,7 @@ from uuid import uuid4

from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, Field, field_validator

from backend.copilot import service as chat_service

@@ -47,7 +47,14 @@ from backend.copilot.rate_limit import (
    release_reset_lock,
    reset_daily_usage,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.response_model import (
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamHeartbeat,
    StreamStart,
    StreamStartStep,
)
from backend.copilot.service import strip_injected_context_for_display
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (

@@ -154,6 +161,14 @@ class StreamChatRequest(BaseModel):
    )


class QueuePendingMessageRequest(BaseModel):
    """Request model for queueing a follow-up while a turn is running."""

    message: str = Field(max_length=64_000)
    context: dict[str, str] | None = None
    file_ids: list[str] | None = Field(default=None, max_length=20)


class PeekPendingMessagesResponse(BaseModel):
    """Response for the pending-message peek (GET) endpoint.

@@ -209,6 +224,11 @@ class ActiveStreamInfo(BaseModel):

    turn_id: str
    last_message_id: str  # Redis Stream message ID for resumption
    # ISO-8601 timestamp (UTC) marking when the backend registered the turn
    # as running. Lets the frontend seed its elapsed-time counter so restored
    # turns show honest "time since turn started" instead of the misleading
    # "time since this mount resumed the SSE".
    started_at: str | None = None


class SessionDetailResponse(BaseModel):

@@ -300,8 +320,11 @@ async def list_sessions(
    redis = await get_redis_async()
    pipe = redis.pipeline(transaction=False)
    for session in sessions:
        # Use the canonical helper so the hash-tag braces match every
        # other writer; building the key inline drops the braces and
        # silently misses every running session on cluster mode.
        pipe.hget(
            f"{config.session_meta_prefix}{session.session_id}",
            stream_registry.get_session_meta_key(session.session_id),
            "status",
        )
    statuses = await pipe.execute()

@@ -529,6 +552,7 @@ async def get_session(
    active_stream_info = ActiveStreamInfo(
        turn_id=active_session.turn_id,
        last_message_id=last_message_id,
        started_at=active_session.created_at.isoformat(),
    )

    # Skip session metadata on "load more" — frontend only needs messages

@@ -816,17 +840,45 @@ async def cancel_session_task(
    return CancelSessionResponse(cancelled=True)


def _ui_message_stream_headers() -> dict[str, str]:
    return {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        "x-vercel-ai-ui-message-stream": "v1",
    }


def _empty_ui_message_stream_response() -> StreamingResponse:
    # Stable placeholder messageId for the empty queued-mid-turn stream.
    # Real turns generate per-message UUIDs via the executor; this stream
    # has no message to attach to, but the AI SDK parser still requires a
    # non-empty ``messageId`` field on ``StreamStart``.
    message_id = uuid4().hex

    async def event_generator() -> AsyncGenerator[str, None]:
        # Vercel AI SDK's UI-message-stream parser expects symmetric
        # start/finish framing at both stream and step level — every
        # non-empty turn emits the pair. Without an opener, today's parser
        # tolerates the closer (no active parts to flush) but a future SDK
        # tightening would silently break the queue-mid-turn UX. Emit the
        # full empty pair so the contract stays correct.
        yield StreamStart(messageId=message_id).to_sse()
        yield StreamStartStep().to_sse()
        yield StreamFinishStep().to_sse()
        yield StreamFinish().to_sse()
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=_ui_message_stream_headers(),
    )
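For reference, the empty stream above serializes to five SSE events. A rough sketch of the wire bytes, assuming the AI SDK v1 UI-message-stream chunk names (`start`, `start-step`, `finish-step`, `finish`); the `to_sse` helper here is illustrative and not the project's:

```python
import json
import uuid


def to_sse(chunk: dict) -> str:
    # SSE framing: one JSON chunk per `data:` event, blank-line terminated.
    return f"data: {json.dumps(chunk)}\n\n"


def empty_ui_message_stream() -> str:
    # Symmetric stream-level and step-level start/finish pairs, then the
    # [DONE] terminator: the same ordering event_generator yields.
    message_id = uuid.uuid4().hex
    return "".join(
        [
            to_sse({"type": "start", "messageId": message_id}),
            to_sse({"type": "start-step"}),
            to_sse({"type": "finish-step"}),
            to_sse({"type": "finish"}),
            "data: [DONE]\n\n",
        ]
    )
```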


@router.post(
    "/sessions/{session_id}/stream",
    responses={
        202: {
            "model": QueuePendingMessageResponse,
            "description": (
                "Session has a turn in flight — message queued into the pending "
                "buffer and will be picked up between tool-call rounds by the "
                "executor currently processing the turn."
            ),
        },
        404: {"description": "Session not found or access denied"},
        429: {"description": "Cost rate-limit or call-frequency cap exceeded"},
    },

@@ -836,19 +888,18 @@ async def stream_chat_post(
    request: StreamChatRequest,
    user_id: str = Security(auth.get_user_id),
):
    """Start a new turn OR queue a follow-up — decided server-side.
    """Start a new turn and return an AI SDK UI message stream.

    - **Session idle**: starts a turn. Returns an SSE stream (``text/event-stream``)
      with Vercel AI SDK chunks (text fragments, tool-call UI, tool results).
      The generation runs in a background task that survives client disconnects;
      reconnect via ``GET /sessions/{session_id}/stream`` to resume.
    Returns an SSE stream (``text/event-stream``) with Vercel AI SDK chunks
    (text fragments, tool-call UI, tool results). The generation runs in a
    background task that survives client disconnects; reconnect via
    ``GET /sessions/{session_id}/stream`` to resume.

    - **Session has a turn in flight**: pushes the message into the per-session
      pending buffer and returns ``202 application/json`` with
      ``QueuePendingMessageResponse``. The executor running the current turn
      drains the buffer between tool-call rounds (baseline) or at the start of
      the next turn (SDK). Clients should detect the 202 and surface the
      message as a queued-chip in the UI.
    Follow-up messages typed while a turn is already running should use
    ``POST /sessions/{session_id}/messages/pending``. If an older client still
    posts that follow-up here, we queue it defensively but still return a valid
    empty UI-message stream so AI SDK transports never receive a JSON body from
    the stream endpoint.

    Args:
        session_id: The chat session identifier.

@@ -872,26 +923,29 @@ async def stream_chat_post(
        extra={"json_fields": log_meta},
    )
    session = await _validate_and_get_session(session_id, user_id)
    builder_permissions = resolve_session_permissions(session)

    # Self-defensive queue-fallback: if a turn is already running, don't race
    # it on the cluster lock — drop the message into the pending buffer and
    # return 202 so the caller can render a chip. Both UI chips and autopilot
    # block follow-ups route through this path; keeping the decision on the
    # server means every caller gets uniform behaviour.
    if (
        request.is_user_message
        and request.message
        and await is_turn_in_flight(session_id)
    ):
        response = await queue_pending_for_http(
            session_id=session_id,
            user_id=user_id,
            message=request.message,
            context=request.context,
            file_ids=request.file_ids,
        )
        return JSONResponse(status_code=202, content=response.model_dump())
    try:
        await queue_pending_for_http(
            session_id=session_id,
            user_id=user_id,
            message=request.message,
            context=request.context,
            file_ids=request.file_ids,
        )
        return _empty_ui_message_stream_response()
    except HTTPException as exc:
        if exc.status_code != 409:
            raise

    # Permission resolution is only needed below for the actual turn — keep
    # it after the queue-fall-through so a queued mid-turn request returns
    # without paying the work.
    builder_permissions = resolve_session_permissions(session)

    logger.info(
        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",

@@ -1130,12 +1184,37 @@ async def stream_chat_post(
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
        },
        headers=_ui_message_stream_headers(),
    )


@router.post(
    "/sessions/{session_id}/messages/pending",
    response_model=QueuePendingMessageResponse,
    responses={
        404: {"description": "Session not found or access denied"},
        409: {"description": "Session has no active turn to receive pending messages"},
        429: {"description": "Call-frequency cap exceeded"},
    },
)
async def queue_pending_message(
    session_id: str,
    request: QueuePendingMessageRequest,
    user_id: str = Security(auth.get_user_id),
):
    """Queue a follow-up message while the session has an active turn."""
    await _validate_and_get_session(session_id, user_id)
    if not await is_turn_in_flight(session_id):
        raise HTTPException(
            status_code=409,
            detail="Session has no active turn. Start a new turn with POST /stream.",
        )
    return await queue_pending_for_http(
        session_id=session_id,
        user_id=user_id,
        message=request.message,
        context=request.context,
        file_ids=request.file_ids,
    )


@@ -1169,6 +1248,7 @@ async def get_pending_messages(
)
async def resume_session_stream(
    session_id: str,
    last_chunk_id: str | None = Query(default=None, include_in_schema=False),
    user_id: str = Security(auth.get_user_id),
):
    """

@@ -1178,27 +1258,26 @@ async def resume_session_stream(
    Checks for an active (in-progress) task on the session and either replays
    the full SSE stream or returns 204 No Content if nothing is running.

    Args:
        session_id: The chat session identifier.
        user_id: Optional authenticated user ID.

    Returns:
        StreamingResponse (SSE) when an active stream exists,
        or 204 No Content when there is nothing to resume.
    Always replays the active turn from ``0-0``. The AI SDK UI-message parser
    keeps text/reasoning part state inside a single parser instance; resuming
    from a Redis cursor can skip the ``*-start`` events required by later
    ``*-delta`` chunks.
    """
    import asyncio

    active_session, last_message_id = await stream_registry.get_active_session(
    active_session, _latest_backend_id = await stream_registry.get_active_session(
        session_id, user_id
    )

    if not active_session:
        return Response(status_code=204)

    # Always replay from the beginning ("0-0") on resume.
    # We can't use last_message_id because it's the latest ID in the backend
    # stream, not the latest the frontend received — the gap causes lost
    # messages. The frontend deduplicates replayed content.
    if last_chunk_id:
        logger.info(
            "Ignoring deprecated last_chunk_id on stream resume",
            extra={"session_id": session_id, "last_chunk_id": last_chunk_id},
        )

    subscriber_queue = await stream_registry.subscribe_to_session(
        session_id=session_id,
        user_id=user_id,

@@ -1259,12 +1338,7 @@ async def resume_session_stream(
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "x-vercel-ai-ui-message-stream": "v1",
        },
        headers=_ui_message_stream_headers(),
    )


@@ -157,6 +157,11 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
        "backend.api.features.chat.routes._validate_and_get_session",
        return_value=None,
    )
    mocker.patch(
        "backend.api.features.chat.routes.is_turn_in_flight",
        new_callable=AsyncMock,
        return_value=False,
    )
    mock_save = mocker.patch(
        "backend.api.features.chat.routes.append_and_save_message",
        return_value=MagicMock(),  # non-None = message was saved (not a duplicate)

@@ -637,7 +642,7 @@ class TestStreamChatRequestModeValidation:
        assert req.mode is None


# ─── POST /stream queue-fallback (when a turn is already in flight) ──
# ─── Pending message queue (when a turn is already in flight) ─────────


def _mock_stream_queue_internals(

@@ -646,11 +651,9 @@ def _mock_stream_queue_internals(
    session_exists: bool = True,
    turn_in_flight: bool = True,
    call_count: int = 1,
    push_length: int | None = 1,
):
    """Mock dependencies for the POST /stream queue-fallback path.

    When ``turn_in_flight`` is True the handler takes the 202 queue branch.
    """
    """Mock dependencies for the pending-message queue path."""
    if session_exists:
        mock_session = mocker.MagicMock()
        mock_session.id = "sess-1"

@@ -692,12 +695,10 @@ def _mock_stream_queue_internals(
        return_value=call_count,
    )
    mocker.patch(
        "backend.copilot.pending_message_helpers.push_pending_message",
        "backend.copilot.pending_message_helpers.push_pending_message_if_session_running",
        new_callable=AsyncMock,
        return_value=1,
        return_value=push_length,
    )
    # queue_user_message re-runs is_turn_in_flight via the helper module —
    # stub that path out too so we don't need a fake stream_registry.
    mocker.patch(
        "backend.copilot.pending_message_helpers.get_active_session_meta",
        new_callable=AsyncMock,

@@ -705,37 +706,65 @@ def _mock_stream_queue_internals(
    )


def test_stream_queue_returns_202_when_turn_in_flight(
def test_queue_pending_message_returns_200_when_turn_in_flight(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Happy path: POST /stream to a session with a live turn → 202 queue."""
    """Happy path: POST /messages/pending to a live turn queues the message."""
    _mock_stream_queue_internals(mocker)

    response = client.post(
        "/sessions/sess-1/stream",
        json={"message": "follow-up", "is_user_message": True},
        "/sessions/sess-1/messages/pending",
        json={"message": "follow-up"},
    )

    assert response.status_code == 202
    assert response.status_code == 200
    data = response.json()
    assert data["buffer_length"] == 1
    assert "turn_in_flight" in data


def test_stream_queue_session_not_found_returns_404(
def test_queue_pending_message_session_not_found_returns_404(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """If the session doesn't exist or belong to the user, returns 404."""
    _mock_stream_queue_internals(mocker, session_exists=False)

    response = client.post(
        "/sessions/bad-sess/stream",
        json={"message": "hi", "is_user_message": True},
        "/sessions/bad-sess/messages/pending",
        json={"message": "hi"},
    )
    assert response.status_code == 404


def test_stream_queue_call_frequency_limit_returns_429(
def test_queue_pending_message_without_active_turn_returns_409(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A pending-message push needs an active turn to consume it."""
    _mock_stream_queue_internals(mocker, turn_in_flight=False)
||||
response = client.post(
|
||||
"/sessions/sess-1/messages/pending",
|
||||
json={"message": "hi"},
|
||||
)
|
||||
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
def test_queue_pending_message_race_after_active_check_returns_409(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""If the active turn ends before the atomic push, the message is not queued."""
|
||||
_mock_stream_queue_internals(mocker, push_length=None)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/messages/pending",
|
||||
json={"message": "hi"},
|
||||
)
|
||||
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
def test_queue_pending_message_call_frequency_limit_returns_429(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Per-user call-frequency cap rejects rapid-fire queued pushes."""
|
||||
@@ -744,14 +773,14 @@ def test_stream_queue_call_frequency_limit_returns_429(
|
||||
_mock_stream_queue_internals(mocker, call_count=PENDING_CALL_LIMIT + 1)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "is_user_message": True},
|
||||
"/sessions/sess-1/messages/pending",
|
||||
json={"message": "hi"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
assert "Too many queued message requests this minute" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_stream_queue_converts_context_dict_to_pending_context(
|
||||
def test_queue_pending_message_converts_context_dict_to_pending_context(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""StreamChatRequest.context is a raw dict; must be coerced to the
|
||||
@@ -768,15 +797,14 @@ def test_stream_queue_converts_context_dict_to_pending_context(
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
"/sessions/sess-1/messages/pending",
|
||||
json={
|
||||
"message": "hi",
|
||||
"is_user_message": True,
|
||||
"context": {"url": "https://example.test", "content": "body"},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.status_code == 200
|
||||
queue_spy.assert_awaited_once()
|
||||
kwargs = queue_spy.await_args.kwargs
|
||||
from backend.copilot.pending_messages import PendingMessageContext
|
||||
@@ -786,7 +814,7 @@ def test_stream_queue_converts_context_dict_to_pending_context(
|
||||
assert kwargs["context"].content == "body"
|
||||
|
||||
|
||||
def test_stream_queue_passes_none_context_when_omitted(
|
||||
def test_queue_pending_message_passes_none_context_when_omitted(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""When request.context is omitted, the queue call receives context=None."""
|
||||
@@ -802,15 +830,31 @@ def test_stream_queue_passes_none_context_when_omitted(
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "is_user_message": True},
|
||||
"/sessions/sess-1/messages/pending",
|
||||
json={"message": "hi"},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.status_code == 200
|
||||
queue_spy.assert_awaited_once()
|
||||
assert queue_spy.await_args.kwargs["context"] is None
|
||||
|
||||
|
||||
def test_stream_chat_queues_legacy_inflight_post_but_returns_sse(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""POST /stream must not return JSON to an AI SDK transport."""
|
||||
_mock_stream_queue_internals(mocker)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "follow-up", "is_user_message": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("text/event-stream")
|
||||
assert '"type":"finish"' in response.text
|
||||
|
||||
|
||||
# ─── get_pending_messages (GET /sessions/{session_id}/messages/pending) ─────
|
||||
|
||||
|
||||
@@ -1581,9 +1625,14 @@ def test_resume_session_stream_no_subscriber_queue(
|
||||
mock_registry.subscribe_to_session = AsyncMock(return_value=None)
|
||||
mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry)
|
||||
|
||||
response = client.get("/sessions/sess-1/stream")
|
||||
response = client.get("/sessions/sess-1/stream?last_chunk_id=9999-9")
|
||||
|
||||
assert response.status_code == 204
|
||||
mock_registry.subscribe_to_session.assert_awaited_once_with(
|
||||
session_id="sess-1",
|
||||
user_id=TEST_USER_ID,
|
||||
last_message_id="0-0",
|
||||
)
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
autogpt_platform/backend/backend/api/features/push/model.py (new file, 20 additions)
@@ -0,0 +1,20 @@
import pydantic


class PushSubscriptionKeys(pydantic.BaseModel):
    p256dh: str = pydantic.Field(min_length=1, max_length=512)
    auth: str = pydantic.Field(min_length=1, max_length=512)


class PushSubscribeRequest(pydantic.BaseModel):
    endpoint: str = pydantic.Field(min_length=1, max_length=2048)
    keys: PushSubscriptionKeys
    user_agent: str | None = pydantic.Field(default=None, max_length=512)


class PushUnsubscribeRequest(pydantic.BaseModel):
    endpoint: str = pydantic.Field(min_length=1, max_length=2048)


class VapidPublicKeyResponse(pydantic.BaseModel):
    public_key: str
autogpt_platform/backend/backend/api/features/push/routes.py (new file, 64 additions)
@@ -0,0 +1,64 @@
from typing import Annotated

from autogpt_libs.auth import get_user_id, requires_user
from fastapi import APIRouter, HTTPException, Security
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST

from backend.api.features.push.model import (
    PushSubscribeRequest,
    PushUnsubscribeRequest,
    VapidPublicKeyResponse,
)
from backend.data.push_subscription import (
    delete_push_subscription,
    upsert_push_subscription,
    validate_push_endpoint,
)
from backend.util.settings import Settings

router = APIRouter()
_settings = Settings()


@router.get(
    "/vapid-key",
    summary="Get VAPID public key for push subscription",
)
async def get_vapid_public_key() -> VapidPublicKeyResponse:
    return VapidPublicKeyResponse(public_key=_settings.secrets.vapid_public_key)


@router.post(
    "/subscribe",
    summary="Register a push subscription for the current user",
    status_code=HTTP_204_NO_CONTENT,
    dependencies=[Security(requires_user)],
)
async def subscribe_push(
    user_id: Annotated[str, Security(get_user_id)],
    body: PushSubscribeRequest,
) -> None:
    try:
        await validate_push_endpoint(body.endpoint)
        await upsert_push_subscription(
            user_id=user_id,
            endpoint=body.endpoint,
            p256dh=body.keys.p256dh,
            auth=body.keys.auth,
            user_agent=body.user_agent,
        )
    except ValueError as e:
        raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=str(e))


@router.post(
    "/unsubscribe",
    summary="Remove a push subscription",
    status_code=HTTP_204_NO_CONTENT,
    dependencies=[Security(requires_user)],
)
async def unsubscribe_push(
    user_id: Annotated[str, Security(get_user_id)],
    body: PushUnsubscribeRequest,
) -> None:
    await delete_push_subscription(user_id, body.endpoint)
@@ -0,0 +1,240 @@
"""Tests for push notification routes."""

from unittest.mock import AsyncMock, MagicMock

import fastapi
import fastapi.testclient
import pytest

from backend.api.features.push.routes import router

app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_get_vapid_public_key(mocker):
    mock_settings = MagicMock()
    mock_settings.secrets.vapid_public_key = "test-vapid-public-key-base64url"
    mocker.patch(
        "backend.api.features.push.routes._settings",
        mock_settings,
    )

    response = client.get("/vapid-key")

    assert response.status_code == 200
    data = response.json()
    assert data["public_key"] == "test-vapid-public-key-base64url"


def test_get_vapid_public_key_empty(mocker):
    mock_settings = MagicMock()
    mock_settings.secrets.vapid_public_key = ""
    mocker.patch(
        "backend.api.features.push.routes._settings",
        mock_settings,
    )

    response = client.get("/vapid-key")

    assert response.status_code == 200
    data = response.json()
    assert data["public_key"] == ""


def test_subscribe_push(mocker, test_user_id):
    mock_upsert = mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
            "user_agent": "Mozilla/5.0 Test",
        },
    )

    assert response.status_code == 204
    mock_upsert.assert_awaited_once_with(
        user_id=test_user_id,
        endpoint="https://fcm.googleapis.com/fcm/send/abc123",
        p256dh="test-p256dh-key",
        auth="test-auth-key",
        user_agent="Mozilla/5.0 Test",
    )


def test_subscribe_push_without_user_agent(mocker, test_user_id):
    mock_upsert = mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 204
    mock_upsert.assert_awaited_once_with(
        user_id=test_user_id,
        endpoint="https://fcm.googleapis.com/fcm/send/abc123",
        p256dh="test-p256dh-key",
        auth="test-auth-key",
        user_agent=None,
    )


def test_subscribe_push_missing_keys():
    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
        },
    )

    assert response.status_code == 422


def test_subscribe_push_missing_endpoint():
    response = client.post(
        "/subscribe",
        json={
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 422


def test_subscribe_push_rejects_empty_crypto_keys():
    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {"p256dh": "", "auth": ""},
        },
    )

    assert response.status_code == 422


def test_subscribe_push_rejects_oversized_endpoint():
    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/" + "x" * 3000,
            "keys": {"p256dh": "k", "auth": "a"},
        },
    )

    assert response.status_code == 422


def test_unsubscribe_push(mocker, test_user_id):
    mock_delete = mocker.patch(
        "backend.api.features.push.routes.delete_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/unsubscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
        },
    )

    assert response.status_code == 204
    mock_delete.assert_awaited_once_with(
        test_user_id,
        "https://fcm.googleapis.com/fcm/send/abc123",
    )


def test_unsubscribe_push_missing_endpoint():
    response = client.post(
        "/unsubscribe",
        json={},
    )

    assert response.status_code == 422


@pytest.mark.parametrize(
    "untrusted_endpoint",
    [
        "https://localhost/evil",
        "https://127.0.0.1/evil",
        "https://169.254.169.254/latest/meta-data/",
        "https://internal-service.local/api",
        "https://attacker.example.com/push",
        "http://fcm.googleapis.com/fcm/send/abc",
        "file:///etc/passwd",
    ],
)
def test_subscribe_push_rejects_untrusted_endpoints(mocker, untrusted_endpoint):
    mock_upsert = mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": untrusted_endpoint,
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 400
    mock_upsert.assert_not_awaited()


def test_subscribe_push_surfaces_cap_as_400(mocker):
    mocker.patch(
        "backend.api.features.push.routes.upsert_push_subscription",
        new_callable=AsyncMock,
        side_effect=ValueError("Subscription limit of 20 per user reached"),
    )

    response = client.post(
        "/subscribe",
        json={
            "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
            "keys": {
                "p256dh": "test-p256dh-key",
                "auth": "test-auth-key",
            },
        },
    )

    assert response.status_code == 400
    assert "Subscription limit" in response.json()["detail"]
@@ -490,6 +490,9 @@ async def get_store_creators(
    # Build where clause with sanitized inputs
    where = {}

    # Only return creators with approved agents
    where["num_agents"] = {"gt": 0}

    if featured:
        where["is_featured"] = featured
@@ -1,4 +1,5 @@
from datetime import datetime
from unittest.mock import AsyncMock

import prisma.enums
import prisma.errors

@@ -50,8 +51,8 @@ async def test_get_store_agents(mocker):

    # Mock prisma calls
    mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
    mock_store_agent.return_value.find_many = mocker.AsyncMock(return_value=mock_agents)
    mock_store_agent.return_value.count = mocker.AsyncMock(return_value=1)
    mock_store_agent.return_value.find_many = AsyncMock(return_value=mock_agents)
    mock_store_agent.return_value.count = AsyncMock(return_value=1)

    # Call function
    result = await db.get_store_agents()

@@ -94,7 +95,7 @@ async def test_get_store_agent_details(mocker):

    # Mock StoreAgent prisma call
    mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
    mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
    mock_store_agent.return_value.find_first = AsyncMock(return_value=mock_agent)

    # Call function
    result = await db.get_store_agent_details("creator", "test-agent")

@@ -133,7 +134,7 @@ async def test_get_store_creator(mocker):

    # Mock prisma call
    mock_creator = mocker.patch("prisma.models.Creator.prisma")
    mock_creator.return_value.find_unique = mocker.AsyncMock()
    mock_creator.return_value.find_unique = AsyncMock()
    # Configure the mock to return values that will pass validation
    mock_creator.return_value.find_unique.return_value = mock_creator_data

@@ -236,23 +237,23 @@ async def test_create_store_submission(mocker):

    # Mock prisma calls
    mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
    mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
    mock_agent_graph.return_value.find_first = AsyncMock(return_value=mock_agent)

    # Mock transaction context manager
    mock_tx = mocker.MagicMock()
    mocker.patch(
        "backend.api.features.store.db.transaction",
        return_value=mocker.AsyncMock(
            __aenter__=mocker.AsyncMock(return_value=mock_tx),
            __aexit__=mocker.AsyncMock(return_value=False),
        return_value=AsyncMock(
            __aenter__=AsyncMock(return_value=mock_tx),
            __aexit__=AsyncMock(return_value=False),
        ),
    )

    mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
    mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
    mock_sl.return_value.find_unique = AsyncMock(return_value=None)

    mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
    mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
    mock_slv.return_value.create = AsyncMock(return_value=mock_version)

    # Call function
    result = await db.create_store_submission(

@@ -292,10 +293,8 @@ async def test_update_profile(mocker):

    # Mock prisma calls
    mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
    mock_profile_db.return_value.find_first = mocker.AsyncMock(
        return_value=mock_profile
    )
    mock_profile_db.return_value.update = mocker.AsyncMock(return_value=mock_profile)
    mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
    mock_profile_db.return_value.update = AsyncMock(return_value=mock_profile)

    # Test data
    profile = Profile(

@@ -336,9 +335,7 @@ async def test_get_user_profile(mocker):

    # Mock prisma calls
    mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
    mock_profile_db.return_value.find_first = mocker.AsyncMock(
        return_value=mock_profile
    )
    mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)

    # Call function
    result = await db.get_user_profile("user-id")

@@ -396,3 +393,38 @@ async def test_get_store_agents_search_category_array_injection():
    # Verify the query executed without error
    # Category should be parameterized, preventing SQL injection
    assert isinstance(result.agents, list)


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creators_only_returns_approved(mocker):
    mock_creators = [
        prisma.models.Creator(
            name="Creator One",
            username="creator1",
            description="desc",
            links=["link1"],
            avatar_url="avatar.jpg",
            num_agents=1,
            agent_rating=4.5,
            agent_runs=10,
            top_categories=["test"],
            is_featured=False,
        )
    ]

    mock_creator = mocker.patch("prisma.models.Creator.prisma")
    mock_creator.return_value.find_many = AsyncMock(return_value=mock_creators)
    mock_creator.return_value.count = AsyncMock(return_value=1)

    result = await db.get_store_creators()

    assert len(result.creators) == 1
    assert result.creators[0].username == "creator1"

    mock_creator.return_value.find_many.assert_called_once()
    mock_creator.return_value.count.assert_called_once()

    _, find_kwargs = mock_creator.return_value.find_many.call_args
    _, count_kwargs = mock_creator.return_value.count.call_args
    assert find_kwargs["where"]["num_agents"] == {"gt": 0}
    assert count_kwargs["where"]["num_agents"] == {"gt": 0}
@@ -245,11 +245,12 @@ def test_get_subscription_status_tier_multipliers_ld_override(
    assert "BUSINESS" not in data["tier_multipliers"]


def test_get_subscription_status_defaults_to_basic(
def test_get_subscription_status_defaults_to_no_tier(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """When all LD price IDs are unset, tier_costs is empty and the caller sees cost=0."""
    """When user has no subscription_tier, defaults to NO_TIER (the explicit
    no-active-subscription state)."""
    mock_user = Mock()
    mock_user.subscription_tier = None

@@ -273,7 +274,7 @@ def test_get_subscription_status_defaults_to_basic(

    assert response.status_code == 200
    data = response.json()
    assert data["tier"] == SubscriptionTier.BASIC.value
    assert data["tier"] == SubscriptionTier.NO_TIER.value
    assert data["monthly_cost"] == 0
    assert data["tier_costs"] == {}
    assert data["proration_credit_cents"] == 0

@@ -326,11 +327,11 @@ def test_get_subscription_status_stripe_error_falls_back_to_zero(
    assert data["tier_costs"]["PRO"] == 0


def test_update_subscription_tier_basic_no_payment(
def test_update_subscription_tier_no_tier_no_payment(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """POST /credits/subscription to BASIC tier when payment disabled skips Stripe."""
    """POST /credits/subscription to NO_TIER (cancel) when payment disabled skips Stripe."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

@@ -351,7 +352,7 @@ def test_update_subscription_tier_basic_no_payment(
        new_callable=AsyncMock,
    )

    response = client.post("/credits/subscription", json={"tier": "BASIC"})
    response = client.post("/credits/subscription", json={"tier": "NO_TIER"})

    assert response.status_code == 200
    assert response.json()["url"] == ""

@@ -404,12 +405,109 @@ def test_update_subscription_tier_paid_requires_urls(
        "backend.api.features.v1.is_feature_enabled",
        side_effect=mock_feature_enabled,
    )
    mocker.patch(
        "backend.api.features.v1.modify_stripe_subscription_for_tier",
        new_callable=AsyncMock,
        return_value=False,
    )

    response = client.post("/credits/subscription", json={"tier": "PRO"})

    assert response.status_code == 422


def test_update_subscription_tier_currency_mismatch_returns_422(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """Stripe rejects a SubscriptionSchedule whose phases mix currencies (e.g.
    GBP-checkout sub trying to schedule a USD-only target Price). The handler
    must convert that into a specific 422 instead of the generic 502 so the
    caller can tell the difference between a currency-config bug and a Stripe
    outage."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.MAX

    async def mock_feature_enabled(*args, **kwargs):
        return True

    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,
        return_value=mock_user,
    )
    mocker.patch(
        "backend.api.features.v1.is_feature_enabled",
        side_effect=mock_feature_enabled,
    )
    mocker.patch(
        "backend.api.features.v1.modify_stripe_subscription_for_tier",
        side_effect=stripe.InvalidRequestError(
            "The price specified only supports `usd`. This doesn't match the"
            " expected currency: `gbp`.",
            param="phases",
        ),
    )

    response = client.post(
        "/credits/subscription",
        json={
            "tier": "PRO",
            "success_url": f"{TEST_FRONTEND_ORIGIN}/success",
            "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
        },
    )

    assert response.status_code == 422
    detail = response.json()["detail"]
    assert "billing currency" in detail.lower()
    assert "contact support" in detail.lower()


def test_update_subscription_tier_non_currency_invalid_request_returns_502(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """Locks the contract that *only* currency-mismatch InvalidRequestErrors
    translate to 422 — every other Stripe InvalidRequestError must still
    surface as the generic 502 so that widening the conditional later is
    caught by the suite."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.MAX

    async def mock_feature_enabled(*args, **kwargs):
        return True

    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,
        return_value=mock_user,
    )
    mocker.patch(
        "backend.api.features.v1.is_feature_enabled",
        side_effect=mock_feature_enabled,
    )
    mocker.patch(
        "backend.api.features.v1.modify_stripe_subscription_for_tier",
        side_effect=stripe.InvalidRequestError(
            "No such price: 'price_does_not_exist'",
            param="items[0][price]",
        ),
    )

    response = client.post(
        "/credits/subscription",
        json={
            "tier": "PRO",
            "success_url": f"{TEST_FRONTEND_ORIGIN}/success",
            "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
        },
    )

    assert response.status_code == 502
    assert "billing currency" not in response.json()["detail"].lower()


def test_update_subscription_tier_creates_checkout(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,

@@ -430,6 +528,11 @@ def test_update_subscription_tier_creates_checkout(
        "backend.api.features.v1.is_feature_enabled",
        side_effect=mock_feature_enabled,
    )
    mocker.patch(
        "backend.api.features.v1.modify_stripe_subscription_for_tier",
        new_callable=AsyncMock,
        return_value=False,
    )
    mocker.patch(
        "backend.api.features.v1.create_subscription_checkout",
        new_callable=AsyncMock,

@@ -469,6 +572,11 @@ def test_update_subscription_tier_rejects_open_redirect(
        "backend.api.features.v1.is_feature_enabled",
        side_effect=mock_feature_enabled,
    )
    mocker.patch(
        "backend.api.features.v1.modify_stripe_subscription_for_tier",
        new_callable=AsyncMock,
        return_value=False,
    )
    checkout_mock = mocker.patch(
        "backend.api.features.v1.create_subscription_checkout",
        new_callable=AsyncMock,

@@ -649,14 +757,14 @@ def test_update_subscription_tier_same_tier_stripe_error_returns_502(
    assert "contact support" in response.json()["detail"].lower()


def test_update_subscription_tier_basic_with_payment_schedules_cancel_and_does_not_update_db(
def test_update_subscription_tier_no_tier_with_payment_schedules_cancel_and_does_not_update_db(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """Downgrading to BASIC schedules Stripe cancellation at period end.
    """Cancelling to NO_TIER schedules Stripe cancellation at period end.

    The DB tier must NOT be updated immediately — the customer.subscription.deleted
    webhook fires at period end and downgrades to BASIC then.
    webhook fires at period end and downgrades to NO_TIER then.
    """
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

@@ -682,18 +790,18 @@ def test_update_subscription_tier_basic_with_payment_schedules_cancel_and_does_n
        side_effect=mock_feature_enabled,
    )

    response = client.post("/credits/subscription", json={"tier": "BASIC"})
    response = client.post("/credits/subscription", json={"tier": "NO_TIER"})

    assert response.status_code == 200
    mock_cancel.assert_awaited_once()
    mock_set_tier.assert_not_awaited()


def test_update_subscription_tier_basic_cancel_failure_returns_502(
def test_update_subscription_tier_no_tier_cancel_failure_returns_502(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """Downgrading to BASIC returns 502 with a generic error (no Stripe detail leakage)."""
    """Cancelling to NO_TIER returns 502 with a generic error (no Stripe detail leakage)."""
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

@@ -716,7 +824,7 @@ def test_update_subscription_tier_basic_cancel_failure_returns_502(
        side_effect=mock_feature_enabled,
    )

    response = client.post("/credits/subscription", json={"tier": "BASIC"})
    response = client.post("/credits/subscription", json={"tier": "NO_TIER"})

    assert response.status_code == 502
    detail = response.json()["detail"]

@@ -921,29 +1029,20 @@ def test_update_subscription_tier_max_checkout(
    checkout_mock.assert_not_awaited()


def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly(
def test_update_subscription_tier_no_active_sub_falls_through_to_checkout(
    client: fastapi.testclient.TestClient,
    mocker: pytest_mock.MockFixture,
) -> None:
    """Admin-granted paid tier users are NOT sent to Stripe checkout for paid→paid changes.
    """Any tier change from a user with no active Stripe sub goes through Checkout.

    When modify_stripe_subscription_for_tier returns False (no Stripe subscription
    found — admin-granted tier), the endpoint must update the DB tier directly and
    return 200 with url="", rather than falling through to Checkout Session creation.
    Admin-granted users (no Stripe sub yet) and never-paid users follow the
    exact same path: modify returns False → Checkout to set up payment. The
    endpoint has no admin-specific branch — admin tier grants happen out-of-band
    via the admin portal, not this user-facing route.
    """
    mock_user = Mock()
    mock_user.subscription_tier = SubscriptionTier.PRO

    async def price_id_with_business(tier: SubscriptionTier) -> str | None:
        return {
            **_DEFAULT_TIER_PRICES,
            SubscriptionTier.BUSINESS: "price_business",
        }.get(tier)

    mocker.patch(
        "backend.api.features.v1.get_subscription_price_id",
        side_effect=price_id_with_business,
    )
    mocker.patch(
        "backend.api.features.v1.get_user_by_id",
        new_callable=AsyncMock,

@@ -954,7 +1053,6 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
        new_callable=AsyncMock,
        return_value=True,
    )
    # Return False = no Stripe subscription (admin-granted tier)
    modify_mock = mocker.patch(
        "backend.api.features.v1.modify_stripe_subscription_for_tier",
        new_callable=AsyncMock,

@@ -967,23 +1065,24 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
    checkout_mock = mocker.patch(
        "backend.api.features.v1.create_subscription_checkout",
        new_callable=AsyncMock,
        return_value="https://checkout.stripe.com/pay/cs_test_no_sub",
    )

    response = client.post(
        "/credits/subscription",
        json={
            "tier": "BUSINESS",
|
||||
"tier": "MAX",
|
||||
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
|
||||
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == ""
|
||||
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS)
|
||||
# DB tier updated directly — no Stripe Checkout Session created
|
||||
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS)
|
||||
checkout_mock.assert_not_awaited()
|
||||
assert response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_no_sub"
|
||||
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.MAX)
|
||||
# No DB-flip — payment must be collected via Checkout regardless of direction.
|
||||
set_tier_mock.assert_not_awaited()
|
||||
checkout_mock.assert_awaited_once()
|
||||
|
||||
|
||||
def test_update_subscription_tier_priced_basic_no_sub_falls_through_to_checkout(
|
||||
@@ -1154,14 +1253,14 @@ def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
|
||||
assert response.status_code == 502
|
||||
|
||||
|
||||
def test_update_subscription_tier_basic_no_stripe_subscription(
|
||||
def test_update_subscription_tier_no_tier_no_stripe_subscription(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Downgrading to BASIC when no Stripe subscription exists updates DB tier directly.
|
||||
"""Cancelling to NO_TIER when no Stripe subscription exists updates DB tier directly.
|
||||
|
||||
Admin-granted paid tiers have no associated Stripe subscription. When such a
|
||||
user requests a self-service downgrade, cancel_stripe_subscription returns False
|
||||
user requests a self-service cancel, cancel_stripe_subscription returns False
|
||||
(nothing to cancel), so the endpoint must immediately call set_subscription_tier
|
||||
rather than waiting for a webhook that will never arrive.
|
||||
"""
|
||||
@@ -1189,13 +1288,13 @@ def test_update_subscription_tier_basic_no_stripe_subscription(
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post("/credits/subscription", json={"tier": "BASIC"})
|
||||
response = client.post("/credits/subscription", json={"tier": "NO_TIER"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["url"] == ""
|
||||
cancel_mock.assert_awaited_once_with(TEST_USER_ID)
|
||||
# DB tier must be updated immediately — no webhook will fire for a missing sub
|
||||
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BASIC)
|
||||
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.NO_TIER)
|
||||
|
||||
|
||||
def test_get_subscription_status_includes_pending_tier(
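The tests above lean on one pattern: patch an async dependency with `AsyncMock`, drive the endpoint, then assert with `assert_awaited_once_with` / `assert_not_awaited`. A minimal stdlib sketch of the no-Stripe-subscription cancel case (names are illustrative stand-ins, not the project's API):

```python
import asyncio
from unittest.mock import AsyncMock

# Stand-ins for the endpoint's async dependencies.
set_subscription_tier = AsyncMock()
cancel_stripe_subscription = AsyncMock(return_value=False)  # nothing to cancel in Stripe

async def cancel_to_no_tier(user_id: str) -> str:
    # Simplified cancel branch: no Stripe sub to cancel → flip the DB tier
    # immediately, since no webhook will ever arrive for a missing sub.
    had_subscription = await cancel_stripe_subscription(user_id)
    if not had_subscription:
        await set_subscription_tier(user_id, "NO_TIER")
    return ""  # a cancel never returns a Checkout URL

url = asyncio.run(cancel_to_no_tier("user-1"))

set_subscription_tier.assert_awaited_once_with("user-1", "NO_TIER")
assert url == ""
```

`AsyncMock` records awaits rather than calls, which is why the assertions are `assert_awaited_*` and not `assert_called_*`.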
The subscription/credits routes (`backend.api.features.v1`):

```diff
@@ -57,12 +57,14 @@ from backend.data.credit import (
     UserCredit,
     cancel_stripe_subscription,
     create_subscription_checkout,
+    get_active_subscription_period_end,
     get_auto_top_up,
     get_pending_subscription_change,
     get_proration_credit_cents,
     get_subscription_price_id,
     get_user_credit_model,
     handle_subscription_payment_failure,
+    handle_subscription_payment_success,
     modify_stripe_subscription_for_tier,
     release_pending_subscription_schedule,
     set_auto_top_up,
@@ -700,13 +702,13 @@ async def get_user_auto_top_up(


 class SubscriptionTierRequest(BaseModel):
-    tier: Literal["BASIC", "PRO", "MAX", "BUSINESS"]
+    tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]
     success_url: str = ""
     cancel_url: str = ""


 class SubscriptionStatusResponse(BaseModel):
-    tier: Literal["BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
+    tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
     monthly_cost: int  # amount in cents (Stripe convention)
     tier_costs: dict[str, int]  # tier name -> amount in cents
     tier_multipliers: dict[str, float] = Field(
@@ -719,7 +721,23 @@ class SubscriptionStatusResponse(BaseModel):
         ),
     )
     proration_credit_cents: int  # unused portion of current sub to convert on upgrade
-    pending_tier: Optional[Literal["BASIC", "PRO", "MAX", "BUSINESS"]] = None
+    has_active_stripe_subscription: bool = Field(
+        default=False,
+        description=(
+            "True when the user has an active/trialing Stripe subscription. The"
+            " frontend uses this to branch upgrade UX: modify-in-place + saved-card"
+            " auto-charge when True, redirect to Stripe Checkout when False."
+        ),
+    )
+    current_period_end: Optional[int] = Field(
+        default=None,
+        description=(
+            "Unix timestamp of the active subscription's current_period_end. Used"
+            " to show the date Stripe will issue the next invoice (with prorated"
+            " upgrade charges, if any). None when no active sub."
+        ),
+    )
+    pending_tier: Optional[Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]] = None
     pending_tier_effective_at: Optional[datetime] = None
     url: str = Field(
         default="",
@@ -804,8 +822,11 @@ async def get_subscription_status(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> SubscriptionStatusResponse:
     user = await get_user_by_id(user_id)
-    tier = user.subscription_tier or SubscriptionTier.BASIC
+    tier = user.subscription_tier or SubscriptionTier.NO_TIER

+    # Tiers that *can* have a Stripe price configured (and therefore appear
+    # in the tier picker if the LD flag exposes a price-id). NO_TIER is not
+    # priceable — it's the implicit "no active subscription" state.
     priceable_tiers = [
         SubscriptionTier.BASIC,
         SubscriptionTier.PRO,
@@ -839,7 +860,10 @@ async def get_subscription_status(
     }

     current_monthly_cost = tier_costs.get(tier.value, 0)
-    proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
+    proration_credit, current_period_end = await asyncio.gather(
+        get_proration_credit_cents(user_id, current_monthly_cost),
+        get_active_subscription_period_end(user_id),
+    )

     try:
         pending = await get_pending_subscription_change(user_id)
@@ -861,10 +885,13 @@ async def get_subscription_status(
         tier_costs=tier_costs,
         tier_multipliers=tier_multipliers,
         proration_credit_cents=proration_credit,
+        has_active_stripe_subscription=current_period_end is not None,
+        current_period_end=current_period_end,
     )
     if pending is not None:
         pending_tier_enum, pending_effective_at = pending
         if pending_tier_enum in (
+            SubscriptionTier.NO_TIER,
             SubscriptionTier.BASIC,
             SubscriptionTier.PRO,
             SubscriptionTier.MAX,
@@ -892,7 +919,7 @@ async def update_subscription_tier(
     # ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
     user = await get_user_by_id(user_id)
     if (
-        user.subscription_tier or SubscriptionTier.BASIC
+        user.subscription_tier or SubscriptionTier.NO_TIER
     ) == SubscriptionTier.ENTERPRISE:
         raise HTTPException(
             status_code=403,
@@ -904,7 +931,7 @@ async def update_subscription_tier(
     # collapsed behaviour that replaces the old /credits/subscription/cancel-pending
     # route. Safe when no pending change exists: release_pending_subscription_schedule
     # returns False and we simply return the current status.
-    if (user.subscription_tier or SubscriptionTier.BASIC) == tier:
+    if (user.subscription_tier or SubscriptionTier.NO_TIER) == tier:
         try:
             await release_pending_subscription_schedule(user_id)
         except stripe.StripeError as e:
@@ -926,18 +953,14 @@ async def update_subscription_tier(
         Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
     )

-    current_tier = user.subscription_tier or SubscriptionTier.BASIC
-    target_price_id, current_tier_price_id = await asyncio.gather(
-        get_subscription_price_id(tier),
-        get_subscription_price_id(current_tier),
-    )
+    target_price_id = await get_subscription_price_id(tier)

-    # Legacy cancel: target BASIC + stripe-price-id-basic unset. Schedule Stripe
-    # cancellation at period end; cancel_at_period_end=True lets the webhook flip
-    # the DB tier. No active sub (admin-granted) or payment disabled → DB flip.
-    # Once stripe-price-id-basic is configured, BASIC becomes a real sub and falls
-    # through to the modify/checkout flow below.
-    if tier == SubscriptionTier.BASIC and target_price_id is None:
+    # Cancel: target NO_TIER. Schedule Stripe cancellation at period end;
+    # cancel_at_period_end=True lets the webhook flip the DB tier. No active
+    # sub (admin-granted or never-paid) or payment disabled → DB flip.
+    # NO_TIER is never priceable, so this branch always fires for cancel
+    # requests regardless of LD config.
+    if tier == SubscriptionTier.NO_TIER:
         if payment_enabled:
             try:
                 had_subscription = await cancel_stripe_subscription(user_id)
@@ -973,32 +996,53 @@ async def update_subscription_tier(
             detail=f"Subscription not available for tier {tier.value}",
         )

-    # User has an active Stripe subscription (current tier has an LD price):
-    # modify it in-place. modify_stripe_subscription_for_tier returns False when no
-    # active sub exists — that's only a "DB-only flip is OK" signal for admin-granted
-    # paid tiers (PRO/BUSINESS with no Stripe record). Priced-BASIC users without a
-    # sub must still go through Checkout so they set up payment.
-    if current_tier_price_id is not None:
-        try:
-            modified = await modify_stripe_subscription_for_tier(user_id, tier)
-            if modified:
-                return await get_subscription_status(user_id)
-            if current_tier != SubscriptionTier.BASIC:
-                await set_subscription_tier(user_id, tier)
-                return await get_subscription_status(user_id)
-        except ValueError as e:
-            raise HTTPException(status_code=422, detail=str(e))
-        except stripe.StripeError as e:
-            logger.exception(
-                "Stripe error modifying subscription for user %s: %s", user_id, e
-            )
-            raise HTTPException(
-                status_code=502,
-                detail=(
-                    "Unable to update your subscription right now. "
-                    "Please try again or contact support."
-                ),
-            )
+    # Modify in place if there's a sub; else fall through to Checkout below.
+    try:
+        modified = await modify_stripe_subscription_for_tier(user_id, tier)
+        if modified:
+            return await get_subscription_status(user_id)
+    except ValueError as e:
+        raise HTTPException(status_code=422, detail=str(e))
+    except stripe.InvalidRequestError as e:
+        # Stripe rejects schedule modify when phases mix currencies, e.g. the
+        # active sub was checked out in GBP but the target tier's Price is
+        # USD-only. 502 reads as outage; surface a 422 with a specific message
+        # so the user/admin can see what to fix in Stripe.
+        msg = str(e)
+        if "currency" in msg.lower():
+            logger.warning(
+                "Currency mismatch on tier change for user %s: %s", user_id, msg
+            )
+            raise HTTPException(
+                status_code=422,
+                detail=(
+                    "Tier change unavailable for your current billing currency."
+                    " Please contact support — the target tier needs to be"
+                    " configured for your currency in Stripe before this"
+                    " change can go through."
+                ),
+            )
+        logger.exception(
+            "Stripe error modifying subscription for user %s: %s", user_id, e
+        )
+        raise HTTPException(
+            status_code=502,
+            detail=(
+                "Unable to update your subscription right now. "
+                "Please try again or contact support."
+            ),
+        )
+    except stripe.StripeError as e:
+        logger.exception(
+            "Stripe error modifying subscription for user %s: %s", user_id, e
+        )
+        raise HTTPException(
+            status_code=502,
+            detail=(
+                "Unable to update your subscription right now. "
+                "Please try again or contact support."
+            ),
+        )

     # No active Stripe subscription → create Stripe Checkout Session.
     if not request.success_url or not request.cancel_url:
@@ -1134,6 +1178,9 @@ async def stripe_webhook(request: Request):
     ):
         await sync_subscription_schedule_from_stripe(data_object)

+    if event_type == "invoice.payment_succeeded":
+        await handle_subscription_payment_success(data_object)
+
     if event_type == "invoice.payment_failed":
         await handle_subscription_payment_failure(data_object)
```
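The rewritten `update_subscription_tier` reduces to a fixed decision order: ENTERPRISE is blocked, a same-tier request releases any pending schedule, a `NO_TIER` target always takes the cancel branch, an in-place modify wins when a Stripe sub exists, and everything else falls through to Checkout. A pure-Python sketch of that ordering (illustrative names, not the project's API):

```python
def plan_tier_change(current: str, target: str, has_stripe_sub: bool) -> str:
    """Return which action the endpoint would take for a tier-change request."""
    if current == "ENTERPRISE":
        return "forbidden"         # admin-managed tier; self-service change → 403
    if current == target:
        return "release_pending"   # same tier: just clear any scheduled change
    if target == "NO_TIER":
        return "cancel"            # NO_TIER is never priceable; always the cancel branch
    if has_stripe_sub:
        return "modify_in_place"   # active Stripe sub: modify it, no Checkout
    return "checkout"              # no sub (admin-granted or never-paid): collect payment

assert plan_tier_change("PRO", "NO_TIER", has_stripe_sub=True) == "cancel"
assert plan_tier_change("PRO", "MAX", has_stripe_sub=True) == "modify_in_place"
assert plan_tier_change("PRO", "MAX", has_stripe_sub=False) == "checkout"
assert plan_tier_change("ENTERPRISE", "MAX", has_stripe_sub=True) == "forbidden"
```

The order matters: checking `NO_TIER` before the modify/checkout split is what makes cancels work the same whether or not a Stripe record exists.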
The REST API app (`backend.api.rest_api`: imports, lifespan, router registration):

```diff
@@ -34,6 +34,7 @@ import backend.api.features.oauth
 import backend.api.features.otto.routes
 import backend.api.features.platform_linking.routes
 import backend.api.features.postmark.postmark
+import backend.api.features.push.routes as push_routes
 import backend.api.features.store.model
 import backend.api.features.store.routes
 import backend.api.features.v1
@@ -41,6 +42,7 @@ import backend.api.features.workspace.routes as workspace_routes
 import backend.data.block
 import backend.data.db
 import backend.data.graph
+import backend.data.redis_client
 import backend.data.user
 import backend.integrations.webhooks.utils
 import backend.util.service
@@ -95,6 +97,8 @@ async def lifespan_context(app: fastapi.FastAPI):
     verify_auth_settings()

     await backend.data.db.connect()
+    # Eager connect to fail-fast if Redis is unreachable.
+    await backend.data.redis_client.get_redis_async()

     # Configure thread pool for FastAPI sync operation performance
     # CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
@@ -146,7 +150,18 @@ async def lifespan_context(app: fastapi.FastAPI):
     except Exception as e:
         logger.warning(f"Error shutting down workspace storage: {e}")

-    await backend.data.db.disconnect()
+    # Each cleanup is wrapped so one failure doesn't block the rest. The
+    # Redis close in particular silences asyncio's "Unclosed ClusterNode"
+    # GC warning at interpreter shutdown.
+    try:
+        await backend.data.redis_client.disconnect_async()
+    except Exception:
+        logger.warning("redis_client.disconnect_async failed", exc_info=True)
+
+    try:
+        await backend.data.db.disconnect()
+    except Exception:
+        logger.warning("db.disconnect failed", exc_info=True)


 def custom_generate_unique_id(route: APIRoute):
@@ -379,6 +394,11 @@ app.include_router(
     tags=["oauth"],
     prefix="/api/oauth",
 )
+app.include_router(
+    push_routes.router,
+    tags=["push"],
+    prefix="/api/push",
+)
 app.include_router(
     backend.api.features.platform_linking.routes.router,
     tags=["platform-linking"],
```
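The shutdown change in this lifespan diff follows a general rule: run each cleanup step in its own `try`/`except` so an exception in one (here Redis) cannot skip the next (the DB disconnect). A stdlib sketch of the pattern, with a deliberately failing first step:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)
closed: list[str] = []

async def close_redis() -> None:
    raise ConnectionError("redis already gone")  # simulate a failing cleanup

async def close_db() -> None:
    closed.append("db")

async def shutdown() -> None:
    # Each step is isolated: a failure is logged, not propagated.
    for name, step in (("redis", close_redis), ("db", close_db)):
        try:
            await step()
        except Exception:
            logger.warning("%s cleanup failed", name, exc_info=True)

asyncio.run(shutdown())
assert closed == ["db"]  # db still disconnected despite the Redis failure
```

The unwrapped version (`await close_redis(); await close_db()`) would leave the DB connection dangling whenever Redis teardown raised.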
The WebSocket API (`backend.api.ws_api`):

```diff
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from contextlib import asynccontextmanager
 from typing import Protocol
@@ -17,14 +16,12 @@ from backend.api.model import (
     WSSubscribeGraphExecutionsRequest,
 )
 from backend.api.utils.cors import build_cors_params
-from backend.data.execution import AsyncRedisExecutionEventBus
-from backend.data.notification_bus import AsyncRedisNotificationEventBus
+from backend.data import db, redis_client
 from backend.data.user import DEFAULT_USER_ID
 from backend.monitoring.instrumentation import (
     instrument_fastapi,
     update_websocket_connections,
 )
-from backend.util.retry import continuous_retry
 from backend.util.service import AppProcess
 from backend.util.settings import AppEnvironment, Config, Settings
@@ -34,10 +31,24 @@ settings = Settings()

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    manager = get_connection_manager()
-    fut = asyncio.create_task(event_broadcaster(manager))
-    fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
-    yield
+    # Prisma is needed to resolve graph_id from graph_exec_id on subscribe.
+    await db.connect()
+    # Eager connect to fail-fast if Redis is unreachable.
+    await redis_client.get_redis_async()
+    try:
+        yield
+    finally:
+        # Each cleanup is wrapped so one failure doesn't block the rest. The
+        # Redis close silences asyncio's "Unclosed ClusterNode" GC warning at
+        # interpreter shutdown.
+        try:
+            await redis_client.disconnect_async()
+        except Exception:
+            logger.warning("redis_client.disconnect_async failed", exc_info=True)
+        try:
+            await db.disconnect()
+        except Exception:
+            logger.warning("db.disconnect failed", exc_info=True)


 docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
@@ -61,31 +72,6 @@ def get_connection_manager():
     return _connection_manager


-@continuous_retry()
-async def event_broadcaster(manager: ConnectionManager):
-    execution_bus = AsyncRedisExecutionEventBus()
-    notification_bus = AsyncRedisNotificationEventBus()
-
-    try:
-
-        async def execution_worker():
-            async for event in execution_bus.listen("*"):
-                await manager.send_execution_update(event)
-
-        async def notification_worker():
-            async for notification in notification_bus.listen("*"):
-                await manager.send_notification(
-                    user_id=notification.user_id,
-                    payload=notification.payload,
-                )
-
-        await asyncio.gather(execution_worker(), notification_worker())
-    finally:
-        # Ensure PubSub connections are closed on any exit to prevent leaks
-        await execution_bus.close()
-        await notification_bus.close()
-
-
 async def authenticate_websocket(websocket: WebSocket) -> str:
     if not settings.config.enable_auth:
         return DEFAULT_USER_ID
@@ -297,6 +283,21 @@ async def websocket_router(
                 ).model_dump_json()
             )
             continue
+        except ValueError as e:
+            logger.warning(
+                "Subscription rejected for user #%s on '%s': %s",
+                user_id,
+                message.method.value,
+                e,
+            )
+            await websocket.send_text(
+                WSMessage(
+                    method=WSMethod.ERROR,
+                    success=False,
+                    error=str(e),
+                ).model_dump_json()
+            )
+            continue
         except Exception as e:
             logger.error(
                 f"Error while handling '{message.method.value}' message "
@@ -321,9 +322,13 @@ async def websocket_router(
             )

     except WebSocketDisconnect:
-        manager.disconnect_socket(websocket, user_id=user_id)
         logger.debug("WebSocket client disconnected")
     except Exception:
         logger.exception(f"Unexpected error in websocket_router for user #{user_id}")
+    finally:
+        # Always release subscription pumps + Redis connections, regardless of how
+        # the loop exited — otherwise non-WebSocketDisconnect failures leak both.
+        await manager.disconnect_socket(websocket, user_id=user_id)
         update_websocket_connections(user_id, -1)
```
Tests for the WebSocket API:

```diff
@@ -44,9 +44,12 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
         "backend.api.ws_api.build_cors_params", return_value=cors_params
     )

-    with override_config(
-        settings, "backend_cors_allow_origins", cors_params["allow_origins"]
-    ), override_config(settings, "app_env", AppEnvironment.LOCAL):
+    with (
+        override_config(
+            settings, "backend_cors_allow_origins", cors_params["allow_origins"]
+        ),
+        override_config(settings, "app_env", AppEnvironment.LOCAL),
+    ):
         WebsocketServer().run()

     build_cors.assert_called_once_with(
@@ -65,9 +68,12 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
 def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
     mocker.patch("backend.api.ws_api.uvicorn.run")

-    with override_config(
-        settings, "backend_cors_allow_origins", ["http://localhost:3000"]
-    ), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
+    with (
+        override_config(
+            settings, "backend_cors_allow_origins", ["http://localhost:3000"]
+        ),
+        override_config(settings, "app_env", AppEnvironment.PRODUCTION),
+    ):
         with pytest.raises(ValueError):
             WebsocketServer().run()

@@ -290,7 +296,232 @@ async def test_handle_unsubscribe_missing_data(
         message=message,
     )

-    mock_manager._unsubscribe.assert_not_called()
+    mock_manager.unsubscribe_graph_exec.assert_not_called()
     mock_websocket.send_text.assert_called_once()
     assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
     assert '"success":false' in mock_websocket.send_text.call_args[0][0]
+
+
+# ---------- Per-graph subscribe branch ----------
+
+
+@pytest.mark.asyncio
+async def test_handle_subscribe_graph_execs_branch(
+    mock_websocket: AsyncMock, mock_manager: AsyncMock
+) -> None:
+    """The SUBSCRIBE_GRAPH_EXECS branch must route to subscribe_graph_execs,
+    not subscribe_graph_exec — regression guard for the aggregate channel."""
+    message = WSMessage(
+        method=WSMethod.SUBSCRIBE_GRAPH_EXECS,
+        data={"graph_id": "graph-abc"},
+    )
+    mock_manager.subscribe_graph_execs.return_value = (
+        "user-1|graph#graph-abc|executions"
+    )
+
+    await handle_subscribe(
+        connection_manager=cast(ConnectionManager, mock_manager),
+        websocket=cast(WebSocket, mock_websocket),
+        user_id="user-1",
+        message=message,
+    )
+
+    mock_manager.subscribe_graph_execs.assert_called_once_with(
+        user_id="user-1",
+        graph_id="graph-abc",
+        websocket=mock_websocket,
+    )
+    mock_manager.subscribe_graph_exec.assert_not_called()
+    mock_websocket.send_text.assert_called_once()
+    assert (
+        '"method":"subscribe_graph_executions"'
+        in mock_websocket.send_text.call_args[0][0]
+    )
+    assert '"success":true' in mock_websocket.send_text.call_args[0][0]
+
+
+@pytest.mark.asyncio
+async def test_handle_subscribe_rejects_unrelated_method(
+    mock_websocket: AsyncMock, mock_manager: AsyncMock
+) -> None:
+    """handle_subscribe must raise for methods that aren't SUBSCRIBE_*."""
+    import pytest as _pytest
+
+    message = WSMessage(
+        method=WSMethod.HEARTBEAT,
+        data={"graph_exec_id": "x"},
+    )
+
+    with _pytest.raises(ValueError):
+        await handle_subscribe(
+            connection_manager=cast(ConnectionManager, mock_manager),
+            websocket=cast(WebSocket, mock_websocket),
+            user_id="user-1",
+            message=message,
+        )
+
+
+# ---------- authenticate_websocket branches ----------
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_missing_token_closes_4001(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", True)
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {}
+
+    user_id = await authenticate_websocket(ws)
+
+    ws.close.assert_awaited_once()
+    assert ws.close.call_args.kwargs["code"] == 4001
+    assert user_id == ""
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_invalid_token_closes_4003(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", True)
+    mocker.patch(
+        "backend.api.ws_api.parse_jwt_token", side_effect=ValueError("bad token")
+    )
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {"token": "abc"}
+
+    user_id = await authenticate_websocket(ws)
+
+    ws.close.assert_awaited_once()
+    assert ws.close.call_args.kwargs["code"] == 4003
+    assert user_id == ""
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_missing_sub_closes_4002(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", True)
+    mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"not_sub": "x"})
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {"token": "abc"}
+
+    user_id = await authenticate_websocket(ws)
+
+    ws.close.assert_awaited_once()
+    assert ws.close.call_args.kwargs["code"] == 4002
+    assert user_id == ""
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_happy_path_returns_sub(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", True)
+    mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"sub": "user-X"})
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {"token": "abc"}
+
+    user_id = await authenticate_websocket(ws)
+
+    assert user_id == "user-X"
+
+
+@pytest.mark.asyncio
+async def test_authenticate_websocket_auth_disabled_returns_default(mocker) -> None:
+    from backend.api.ws_api import authenticate_websocket
+
+    mocker.patch.object(settings.config, "enable_auth", False)
+    ws = AsyncMock(spec=WebSocket)
+    ws.query_params = {}
+
+    user_id = await authenticate_websocket(ws)
+
+    assert user_id == DEFAULT_USER_ID
+
+
+# ---------- get_connection_manager singleton ----------
+
+
+def test_get_connection_manager_singleton() -> None:
+    """Repeated calls must return the same ConnectionManager — the WS router
+    depends on a single process-wide subscription table."""
+    import backend.api.ws_api as ws_api
+
+    ws_api._connection_manager = None
+    a = ws_api.get_connection_manager()
+    b = ws_api.get_connection_manager()
+    assert a is b
+    assert isinstance(a, ConnectionManager)
+
+
+# ---------- Lifespan: Prisma connect/disconnect ----------
+
+
+@pytest.mark.asyncio
+async def test_lifespan_connects_and_disconnects_prisma(mocker) -> None:
+    """Lifespan must both connect() and disconnect() db — the subscribe path
+    resolves graph_id via Prisma so a missing connect() is the regression bug."""
+    from fastapi import FastAPI
+
+    from backend.api.ws_api import lifespan
+
+    mock_db = mocker.patch("backend.api.ws_api.db")
+    mock_db.connect = AsyncMock()
+    mock_db.disconnect = AsyncMock()
+
+    dummy_app = FastAPI()
+    async with lifespan(dummy_app):
+        mock_db.connect.assert_awaited_once()
+        mock_db.disconnect.assert_not_called()
+    mock_db.disconnect.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_lifespan_still_disconnects_on_exception(mocker) -> None:
+    """If the app raises inside the yield, Prisma must still disconnect."""
+    from fastapi import FastAPI
+
+    from backend.api.ws_api import lifespan
+
+    mock_db = mocker.patch("backend.api.ws_api.db")
+    mock_db.connect = AsyncMock()
+    mock_db.disconnect = AsyncMock()
+
+    dummy_app = FastAPI()
+
+    class _Boom(Exception):
+        pass
+
+    with pytest.raises(_Boom):
+        async with lifespan(dummy_app):
+            raise _Boom()
+
+    mock_db.disconnect.assert_awaited_once()
+
+
+# ---------- Health endpoint ----------
+
+
+def test_health_endpoint_returns_ok() -> None:
+    # TestClient triggers lifespan — stub it out so Prisma isn't hit.
+    from contextlib import asynccontextmanager
+
+    from fastapi.testclient import TestClient
+
+    import backend.api.ws_api as ws_api
+
+    @asynccontextmanager
+    async def _noop_lifespan(app):
+        yield
+
+    # Replace the app-level lifespan temporarily.
+    real_router_lifespan = ws_api.app.router.lifespan_context
+    ws_api.app.router.lifespan_context = _noop_lifespan
+    try:
+        with TestClient(ws_api.app) as client:
+            r = client.get("/")
+            assert r.status_code == 200
+            assert r.json() == {"status": "healthy"}
+    finally:
+        ws_api.app.router.lifespan_context = real_router_lifespan
```
The service entrypoint that launches the app processes:

```diff
@@ -38,6 +38,7 @@ def main(**kwargs):

     from backend.api.rest_api import AgentServer
     from backend.api.ws_api import WebsocketServer
+    from backend.copilot.bot.app import CoPilotChatBridge
     from backend.copilot.executor.manager import CoPilotExecutor
     from backend.data.db_manager import DatabaseManager
     from backend.executor import ExecutionManager, Scheduler
@@ -53,6 +54,7 @@ def main(**kwargs):
         AgentServer(),
         ExecutionManager(),
         CoPilotExecutor(),
+        CoPilotChatBridge(),
         **kwargs,
     )
```
The AutoPilot block (`AutoPilotBlock`):

```diff
@@ -7,6 +7,7 @@ import logging
 import uuid
 from typing import TYPE_CHECKING, Any

+from pydantic import field_validator
 from typing_extensions import TypedDict  # Needed for Python <3.12 compatibility

 from backend.blocks._base import (
@@ -17,6 +18,7 @@ from backend.blocks._base import (
     BlockSchemaOutput,
 )
 from backend.copilot.permissions import (
+    DISABLED_LEGACY_TOOL_NAMES,
     CopilotPermissions,
     ToolName,
     all_known_tool_names,
@@ -198,6 +200,13 @@ class AutoPilotBlock(Block):
         # timeouts internally; wrapping with asyncio.timeout corrupts the
         # SDK's internal stream (see service.py CRITICAL comment).

+        @field_validator("tools", mode="before")
+        @classmethod
+        def strip_disabled_legacy_tools(cls, tools: Any) -> Any:
+            if not isinstance(tools, list):
+                return tools
+            return [tool for tool in tools if tool not in DISABLED_LEGACY_TOOL_NAMES]
+
     class Output(BlockSchemaOutput):
         """Output schema for the AutoPilot block."""
```
|
||||
|
||||
|
||||
@@ -62,6 +62,14 @@ class TestBuildAndValidatePermissions:
        with pytest.raises(ValidationError, match="not_a_real_tool"):
            _make_input(tools=["not_a_real_tool"])

    async def test_disabled_legacy_tool_is_accepted_and_removed(self):
        inp = _make_input(tools=["ask_question", "run_block"])
        result = await _build_and_validate_permissions(inp)

        assert inp.tools == ["run_block"]
        assert isinstance(result, CopilotPermissions)
        assert result.tools == ["run_block"]

    async def test_valid_block_name_accepted(self):
        mock_block_cls = MagicMock()
        mock_block_cls.return_value.name = "HTTP Request"

autogpt_platform/backend/backend/copilot/bot/AGENTS.md (new file, 1 line)
@@ -0,0 +1 @@
@README.md

autogpt_platform/backend/backend/copilot/bot/CLAUDE.md (new file, 1 line)
@@ -0,0 +1 @@
@README.md

autogpt_platform/backend/backend/copilot/bot/README.md (new file, 79 lines)
@@ -0,0 +1,79 @@
# CoPilot Bot

Multi-platform chat bot that bridges AutoPilot to Discord (and later Telegram, Slack, etc.).

## Running

```bash
# As a standalone service
poetry run copilot-bot

# Or auto-start alongside the rest of the platform
poetry run app  # starts the bot too if AUTOPILOT_BOT_DISCORD_TOKEN is set
```

## Required environment variables

See `backend/.env.default` for the full list with documentation. Minimum setup:

| Variable | Purpose |
|----------|---------|
| `AUTOPILOT_BOT_DISCORD_TOKEN` | Discord bot token — enables the Discord adapter |
| `FRONTEND_BASE_URL` | Frontend base URL for link confirmation pages (shared with the rest of the backend) |
| `REDIS_HOST` / `REDIS_PORT` | Session, thread subscription, and copilot stream subscription state (inherited from the shared backend config) |
| `PLATFORMLINKINGMANAGER_HOST` | DNS name of the `PlatformLinkingManager` service pod (cluster-internal RPC) |

## Architecture

```
bot/
├── app.py            # CoPilotChatBridge(AppService), adapter factory, outbound @expose RPC
├── config.py         # Shared (platform-agnostic) config
├── handler.py        # Core logic: routing, linking, batched streaming
├── bot_backend.py    # Thin facade over PlatformLinkingManagerClient + stream_registry
├── text.py           # Text splitting + batch formatting
├── threads.py        # Redis-backed thread subscription tracking
└── adapters/
    ├── base.py       # PlatformAdapter interface + MessageContext
    └── discord/
        ├── adapter.py   # Gateway connection, events, sends, thread creation
        ├── commands.py  # Slash commands (/setup, /help, /unlink)
        └── config.py    # Discord token + platform limits
```

**Locality rule:** anything platform-specific lives under `adapters/<platform>/`.
The only file that names specific platforms is `app.py`, which is the factory
that decides which adapters to instantiate based on which tokens are set.

## How messaging works

1. User mentions the bot in a channel
2. Adapter's `on_message` handler fires, constructs a `MessageContext`, and passes
   it to the shared `MessageHandler`
3. Handler:
   - Checks if the user/server is linked (via `bot_backend`)
   - If not linked → sends a "Link Account" button prompt
   - If linked → creates a thread (for channels) or uses the existing thread/DM
   - Marks the thread as subscribed in Redis (7-day TTL)
   - Streams the AutoPilot response back, chunked at the adapter's
     `chunk_flush_at` boundary
4. Messages that arrive while a stream is running get batched and sent as a
   single follow-up turn once the current stream ends

## Adding a new platform

1. Create `adapters/<platform>/` with `adapter.py`, `commands.py` (if the
   platform has commands), and `config.py`
2. `adapter.py` subclasses `PlatformAdapter` and implements all its abstract
   methods — `max_message_length`, `chunk_flush_at`, `send_message`,
   `send_link`, `create_thread`, etc.
3. `config.py` declares the platform's env vars and any platform-specific
   numbers (message limits, token name, etc.)
4. Add two lines to `app.py::_build_adapters`:
   ```python
   if <platform>_config.BOT_TOKEN:
       adapters.append(<Platform>Adapter(api))
   ```

The core handler, text utilities, thread tracking, and platform API all stay
untouched.
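The token-gated factory in step 4 can be sketched self-contained. This is an illustrative stand-in, not the real `app.py`: `_PlatformConfig` and `_StubAdapter` are hypothetical placeholders for the per-platform `config.py` and adapter classes; only the "instantiate an adapter iff its token is set" shape is taken from the README.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _PlatformConfig:
    # Stand-in for a platform's config.py; BOT_TOKEN mirrors the README's example.
    BOT_TOKEN: Optional[str]


class _StubAdapter:
    # Stand-in for a platform adapter class; real ones subclass PlatformAdapter.
    def __init__(self, api: Any, name: str) -> None:
        self.api = api
        self.platform_name = name


def _build_adapters(api: Any, configs: dict[str, _PlatformConfig]) -> list[_StubAdapter]:
    adapters: list[_StubAdapter] = []
    for name, cfg in configs.items():
        if cfg.BOT_TOKEN:  # only spin up platforms whose token is configured
            adapters.append(_StubAdapter(api, name))
    return adapters


adapters = _build_adapters(
    api=object(),
    configs={
        "discord": _PlatformConfig(BOT_TOKEN="token-123"),
        "telegram": _PlatformConfig(BOT_TOKEN=None),  # unset -> adapter skipped
    },
)
print([a.platform_name for a in adapters])
```

Adding a platform then really is two lines: one config check, one append; nothing else in the factory changes.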
autogpt_platform/backend/backend/copilot/bot/__main__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Entry point for running the CoPilot Chat Bridge service.

Usage:
    poetry run copilot-bot
    python -m backend.copilot.bot
"""

from backend.app import run_processes

from .app import CoPilotChatBridge


def main():
    """Run the CoPilot Chat Bridge service."""
    run_processes(CoPilotChatBridge())


if __name__ == "__main__":
    main()

autogpt_platform/backend/backend/copilot/bot/adapters/base.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""Abstract base for platform adapters.

Each chat platform (Discord, Telegram, Slack, etc.) implements this interface.
The core bot logic in handler.py is platform-agnostic — it only speaks through
these methods.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Awaitable, Callable, Literal, Optional

# Callback signature: (ctx, adapter) -> awaitable None
MessageCallback = Callable[["MessageContext", "PlatformAdapter"], Awaitable[None]]

# Where the message came from:
# - "dm" — 1:1 conversation, reply in-place
# - "channel" — public channel, bot was @mentioned, create a thread to respond
# - "thread" — ongoing thread conversation, reply in-place
ChannelType = Literal["dm", "channel", "thread"]


@dataclass
class MessageContext:
    """Everything the core handler needs to know about an incoming message."""

    platform: str
    channel_type: ChannelType
    server_id: Optional[str]
    channel_id: str  # DM channel ID / parent channel ID / thread ID
    message_id: str  # the incoming message itself — used to create threads from it
    user_id: str
    username: str
    text: str  # with bot mentions stripped

    @property
    def is_dm(self) -> bool:
        return self.channel_type == "dm"


class PlatformAdapter(ABC):
    """Interface that each chat platform must implement."""

    @property
    @abstractmethod
    def platform_name(self) -> str: ...

    @abstractmethod
    def on_message(self, callback: MessageCallback) -> None: ...

    @abstractmethod
    async def start(self) -> None: ...

    @abstractmethod
    async def stop(self) -> None: ...

    @abstractmethod
    async def send_message(self, channel_id: str, text: str) -> None: ...

    @abstractmethod
    async def send_link(
        self, channel_id: str, text: str, link_label: str, link_url: str
    ) -> None:
        """Send a message with a clickable link presented as a button/CTA.

        Platforms without native button support should fall back to rendering
        the URL inline in the text.
        """
        ...

    @abstractmethod
    async def send_reply(
        self, channel_id: str, text: str, reply_to_message_id: str
    ) -> None: ...

    @abstractmethod
    async def send_ephemeral(
        self, channel_id: str, user_id: str, text: str
    ) -> None: ...

    @abstractmethod
    async def start_typing(self, channel_id: str) -> None: ...

    @abstractmethod
    async def stop_typing(self, channel_id: str) -> None: ...

    @abstractmethod
    async def create_thread(
        self, channel_id: str, message_id: str, name: str
    ) -> Optional[str]:
        """Create a thread from a message. Returns the thread ID, or None if
        the platform doesn't support threads or creation failed.
        """
        ...

    @property
    @abstractmethod
    def max_message_length(self) -> int:
        """Hard platform cap on a single message's content length."""
        ...

    @property
    @abstractmethod
    def chunk_flush_at(self) -> int:
        """Flush the streaming buffer once it reaches this length.

        Should be slightly under max_message_length to leave headroom for
        any trailing content that the splitter might pull into the current
        chunk.
        """
        ...
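The point of this interface is that the core handler can be exercised against any object satisfying it. A minimal, self-contained sketch of that split, using a trimmed-down re-declaration of the interface for illustration (this is not the real `backend` module, and `MiniAdapter`/`InMemoryAdapter`/`handle` are hypothetical names):

```python
import asyncio
from abc import ABC, abstractmethod


class MiniAdapter(ABC):
    # Two-method slice of the PlatformAdapter interface, for illustration only.
    @property
    @abstractmethod
    def platform_name(self) -> str: ...

    @abstractmethod
    async def send_message(self, channel_id: str, text: str) -> None: ...


class InMemoryAdapter(MiniAdapter):
    """Records sends instead of talking to a real chat platform."""

    def __init__(self) -> None:
        self.sent: list[tuple[str, str]] = []

    @property
    def platform_name(self) -> str:
        return "memory"

    async def send_message(self, channel_id: str, text: str) -> None:
        self.sent.append((channel_id, text))


async def handle(adapter: MiniAdapter, channel_id: str, text: str) -> None:
    # Platform-agnostic core logic: it only speaks through the abstract methods,
    # never importing discord/telegram/etc.
    await adapter.send_message(channel_id, f"[{adapter.platform_name}] {text}")


adapter = InMemoryAdapter()
asyncio.run(handle(adapter, "42", "hello"))
print(adapter.sent)
```

The same pattern is what makes the adapter test files below possible: a fake adapter (or a mocked `_client`) is enough to cover the shared logic without a live gateway.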
@@ -0,0 +1,209 @@
"""Discord adapter — connects to the Discord Gateway and forwards messages.

Platform-specific machinery only: Gateway connection, message event handling,
thread creation, typing, button rendering. All platform-agnostic logic lives
in the core handler. Slash commands live in commands.py.
"""

import logging
from typing import Optional

import discord
from discord import app_commands

from backend.copilot.bot.bot_backend import BotBackend

from ..base import ChannelType, MessageCallback, MessageContext, PlatformAdapter
from . import commands, config

logger = logging.getLogger(__name__)


class DiscordAdapter(PlatformAdapter):
    def __init__(self, api: BotBackend):
        intents = discord.Intents.default()
        intents.message_content = True
        # AutoPilot output is untrusted w.r.t. mentions — suppress @everyone,
        # role, and user pings the LLM might produce. Client-level default
        # applies to every send() + reply() below.
        self._client = discord.Client(
            intents=intents,
            allowed_mentions=discord.AllowedMentions.none(),
        )
        self._tree = app_commands.CommandTree(self._client)
        self._api = api
        self._on_message_callback: Optional[MessageCallback] = None
        self._commands_synced = False

        self._register_events()
        commands.register(self._tree, self._api)

    @property
    def platform_name(self) -> str:
        return "discord"

    @property
    def max_message_length(self) -> int:
        return config.MAX_MESSAGE_LENGTH

    @property
    def chunk_flush_at(self) -> int:
        return config.CHUNK_FLUSH_AT

    def on_message(self, callback: MessageCallback) -> None:
        self._on_message_callback = callback

    async def start(self) -> None:
        await self._client.start(config.get_bot_token())

    async def stop(self) -> None:
        if not self._client.is_closed():
            await self._client.close()

    async def _resolve_channel(self, channel_id: str):
        """Return the channel for ``channel_id``, falling back to a REST fetch.

        ``Client.get_channel`` only reads the in-memory cache, so it misses
        threads the bot hasn't seen since its last restart. Fall back to
        ``fetch_channel`` (REST) so long-lived threads keep working.
        """
        channel = self._client.get_channel(int(channel_id))
        if channel is not None:
            return channel
        try:
            return await self._client.fetch_channel(int(channel_id))
        except (discord.NotFound, discord.Forbidden, discord.HTTPException):
            logger.warning("Channel %s not found or inaccessible", channel_id)
            return None

    async def send_message(self, channel_id: str, text: str) -> None:
        channel = await self._resolve_channel(channel_id)
        if channel and isinstance(channel, discord.abc.Messageable):
            # tts=False is the default but we pin it explicitly — AutoPilot
            # output is untrusted and should never blast through voice.
            await channel.send(text, tts=False)

    async def send_link(
        self, channel_id: str, text: str, link_label: str, link_url: str
    ) -> None:
        channel = await self._resolve_channel(channel_id)
        if channel is None or not isinstance(channel, discord.abc.Messageable):
            return
        view = discord.ui.View()
        view.add_item(
            discord.ui.Button(
                style=discord.ButtonStyle.link,
                label=link_label[:80],  # Discord button label max
                url=link_url,
            )
        )
        await channel.send(text, view=view, tts=False)

    async def send_reply(
        self, channel_id: str, text: str, reply_to_message_id: str
    ) -> None:
        channel = await self._resolve_channel(channel_id)
        if not channel or not isinstance(channel, discord.abc.Messageable):
            return
        try:
            msg = await channel.fetch_message(int(reply_to_message_id))
            await msg.reply(text, tts=False)
        except discord.NotFound:
            await channel.send(text, tts=False)

    async def send_ephemeral(self, channel_id: str, user_id: str, text: str) -> None:
        # Ephemeral messages are only possible via interaction responses.
        # Fall back to a normal message for non-interaction contexts.
        await self.send_message(channel_id, text)

    async def start_typing(self, channel_id: str) -> None:
        channel = await self._resolve_channel(channel_id)
        if channel and isinstance(channel, discord.abc.Messageable):
            await channel.typing()

    async def stop_typing(self, channel_id: str) -> None:
        pass  # Discord typing auto-expires after ~10s

    async def create_thread(
        self, channel_id: str, message_id: str, name: str
    ) -> Optional[str]:
        channel = await self._resolve_channel(channel_id)
        if channel is None or not isinstance(channel, discord.TextChannel):
            logger.warning("Cannot create thread in non-text channel %s", channel_id)
            return None
        try:
            msg = await channel.fetch_message(int(message_id))
            thread = await msg.create_thread(name=name[:100])
            return str(thread.id)
        except discord.HTTPException:
            logger.exception("Failed to create thread in channel %s", channel_id)
            return None

    # -- Internal --

    def _register_events(self) -> None:
        @self._client.event
        async def on_ready() -> None:
            logger.info(f"Discord bot connected as {self._client.user}")
            # Sync slash commands once per process — on_ready fires on every
            # gateway reconnect, but the command tree only needs uploading once.
            if self._commands_synced:
                return
            try:
                synced = await self._tree.sync()
                self._commands_synced = True
                logger.info(f"Synced {len(synced)} slash commands")
            except Exception:
                logger.exception("Failed to sync slash commands")

        @self._client.event
        async def on_message(message: discord.Message) -> None:
            if message.author.bot:
                return
            if self._on_message_callback is None:
                return

            channel_type = self._channel_type(message)

            # Channels require an explicit @mention; DMs and threads always
            # forward (handler checks thread subscription).
            if channel_type == "channel" and not self._is_mentioned(message):
                return

            ctx = MessageContext(
                platform="discord",
                channel_type=channel_type,
                server_id=str(message.guild.id) if message.guild else None,
                channel_id=str(message.channel.id),
                message_id=str(message.id),
                user_id=str(message.author.id),
                username=message.author.display_name,
                text=self._strip_mentions(message),
            )
            await self._on_message_callback(ctx, self)

    def _is_mentioned(self, message: discord.Message) -> bool:
        if message.guild is None:
            return True  # DMs always count
        return bool(self._client.user and self._client.user.mentioned_in(message))

    @staticmethod
    def _channel_type(message: discord.Message) -> ChannelType:
        if message.guild is None:
            return "dm"
        if isinstance(message.channel, discord.Thread):
            return "thread"
        return "channel"

    def _strip_mentions(self, message: discord.Message) -> str:
        """Strip the bot's own mention; replace other users' raw mention
        tokens with `@displayname` so the LLM keeps the context.
        """
        text = message.content
        bot_id = self._client.user.id if self._client.user else None
        for user in message.mentions:
            raw_tokens = (f"<@{user.id}>", f"<@!{user.id}>")
            replacement = "" if user.id == bot_id else f"@{user.display_name}"
            for token in raw_tokens:
                text = text.replace(token, replacement)
        return text.strip()
@@ -0,0 +1,259 @@
"""Tests for DiscordAdapter helpers that don't need a live gateway."""

from typing import cast
from unittest.mock import AsyncMock, MagicMock

import discord
import pytest

from backend.copilot.bot.adapters.discord.adapter import DiscordAdapter


def _bare_adapter(bot_id: int | None = 1000) -> tuple[DiscordAdapter, MagicMock]:
    """Build a DiscordAdapter without going through __init__ (which spins up
    discord.py internals). Returns the adapter alongside the MagicMock that
    stands in for ``_client`` — tests reach into the mock directly for
    per-method stubbing.
    """
    adapter = DiscordAdapter.__new__(DiscordAdapter)
    client = MagicMock()
    client.user = MagicMock(id=bot_id) if bot_id is not None else None
    adapter._client = cast(discord.Client, client)
    adapter._on_message_callback = None
    adapter._commands_synced = False
    return adapter, client


def _mention(user_id: int, display_name: str) -> MagicMock:
    user = MagicMock()
    user.id = user_id
    user.display_name = display_name
    return user


def _message(content: str, mentions: list[MagicMock]) -> MagicMock:
    msg = MagicMock()
    msg.content = content
    msg.mentions = mentions
    return msg


# ── _strip_mentions ────────────────────────────────────────────────────


class TestStripMentions:
    def test_strips_only_bot_mention(self):
        adapter, _ = _bare_adapter(bot_id=1000)
        bot = _mention(1000, "AutoPilot")
        alice = _mention(2000, "Alice")
        msg = _message(
            "<@1000> please summarise what <@2000> said",
            mentions=[bot, alice],
        )

        assert adapter._strip_mentions(msg) == "please summarise what @Alice said"

    def test_handles_nickname_style_tokens(self):
        adapter, _ = _bare_adapter(bot_id=1000)
        bot = _mention(1000, "AutoPilot")
        alice = _mention(2000, "Alice")
        msg = _message("<@!1000> ping <@!2000>", mentions=[bot, alice])

        assert adapter._strip_mentions(msg) == "ping @Alice"

    def test_no_bot_user_leaves_all_mentions_as_names(self):
        adapter, _ = _bare_adapter(bot_id=None)
        alice = _mention(2000, "Alice")
        msg = _message("hi <@2000>", mentions=[alice])

        assert adapter._strip_mentions(msg) == "hi @Alice"

    def test_message_without_mentions_is_trimmed(self):
        adapter, _ = _bare_adapter(bot_id=1000)
        msg = _message(" hello world ", mentions=[])

        assert adapter._strip_mentions(msg) == "hello world"

    @pytest.mark.parametrize(
        "content,expected",
        [
            ("<@1000>", ""),
            ("<@!1000>", ""),
            ("<@1000> hi", "hi"),
            ("hi <@1000>", "hi"),
        ],
    )
    def test_bot_only_variants(self, content: str, expected: str):
        adapter, _ = _bare_adapter(bot_id=1000)
        bot = _mention(1000, "AutoPilot")
        msg = _message(content, mentions=[bot])

        assert adapter._strip_mentions(msg) == expected


# ── _channel_type ──────────────────────────────────────────────────────


class TestChannelType:
    def test_dm_has_no_guild(self):
        msg = MagicMock()
        msg.guild = None
        assert DiscordAdapter._channel_type(msg) == "dm"

    def test_thread_inside_guild(self):
        msg = MagicMock()
        msg.guild = MagicMock()
        msg.channel = MagicMock(spec=discord.Thread)
        assert DiscordAdapter._channel_type(msg) == "thread"

    def test_regular_channel_inside_guild(self):
        msg = MagicMock()
        msg.guild = MagicMock()
        msg.channel = MagicMock()
        assert DiscordAdapter._channel_type(msg) == "channel"


# ── _is_mentioned ──────────────────────────────────────────────────────


class TestIsMentioned:
    def test_dm_always_counts_as_mentioned(self):
        adapter, _ = _bare_adapter(bot_id=1000)
        msg = MagicMock()
        msg.guild = None
        assert adapter._is_mentioned(msg) is True

    def test_guild_requires_explicit_mention(self):
        adapter, client = _bare_adapter(bot_id=1000)
        msg = MagicMock()
        msg.guild = MagicMock()
        client.user.mentioned_in.return_value = False
        assert adapter._is_mentioned(msg) is False

    def test_guild_with_mention_passes(self):
        adapter, client = _bare_adapter(bot_id=1000)
        msg = MagicMock()
        msg.guild = MagicMock()
        client.user.mentioned_in.return_value = True
        assert adapter._is_mentioned(msg) is True

    def test_no_bot_user_treats_guild_mention_as_false(self):
        adapter, _ = _bare_adapter(bot_id=None)
        msg = MagicMock()
        msg.guild = MagicMock()
        assert adapter._is_mentioned(msg) is False


# ── _resolve_channel ───────────────────────────────────────────────────


class TestResolveChannel:
    @pytest.mark.asyncio
    async def test_cache_hit_skips_rest_fetch(self):
        adapter, client = _bare_adapter()
        cached = MagicMock()
        client.get_channel.return_value = cached
        client.fetch_channel = AsyncMock()

        result = await adapter._resolve_channel("123")

        assert result is cached
        client.fetch_channel.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_cache_miss_falls_back_to_rest(self):
        adapter, client = _bare_adapter()
        fetched = MagicMock()
        client.get_channel.return_value = None
        client.fetch_channel = AsyncMock(return_value=fetched)

        result = await adapter._resolve_channel("123")

        assert result is fetched
        client.fetch_channel.assert_awaited_once_with(123)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "exc",
        [
            discord.NotFound(MagicMock(status=404), "gone"),
            discord.Forbidden(MagicMock(status=403), "nope"),
            discord.HTTPException(MagicMock(status=500), "boom"),
        ],
    )
    async def test_rest_failure_returns_none(self, exc: Exception):
        adapter, client = _bare_adapter()
        client.get_channel.return_value = None
        client.fetch_channel = AsyncMock(side_effect=exc)

        assert await adapter._resolve_channel("123") is None


# ── send_message / send_reply / send_link ──────────────────────────────


class TestSendMethods:
    @pytest.mark.asyncio
    async def test_send_message_pins_tts_false(self):
        adapter, client = _bare_adapter()
        channel = MagicMock(spec=discord.TextChannel)
        channel.send = AsyncMock()
        client.get_channel.return_value = channel

        await adapter.send_message("123", "hi")

        channel.send.assert_awaited_once_with("hi", tts=False)

    @pytest.mark.asyncio
    async def test_send_message_silently_drops_when_channel_missing(self):
        adapter, client = _bare_adapter()
        client.get_channel.return_value = None
        client.fetch_channel = AsyncMock(
            side_effect=discord.NotFound(MagicMock(status=404), "gone")
        )
        # Should not raise even though there's nothing to send to.
        await adapter.send_message("123", "hi")

    @pytest.mark.asyncio
    async def test_send_link_attaches_button_and_pins_tts(self):
        adapter, client = _bare_adapter()
        channel = MagicMock(spec=discord.TextChannel)
        channel.send = AsyncMock()
        client.get_channel.return_value = channel

        await adapter.send_link("123", "click me", "Open", "https://example.com")

        assert channel.send.await_count == 1
        kwargs = channel.send.await_args.kwargs
        assert kwargs["tts"] is False
        view = kwargs["view"]
        assert any(
            getattr(c, "url", None) == "https://example.com" for c in view.children
        )

    @pytest.mark.asyncio
    async def test_send_reply_falls_back_to_send_when_message_missing(self):
        adapter, client = _bare_adapter()
        channel = MagicMock(spec=discord.TextChannel)
        channel.send = AsyncMock()
        channel.fetch_message = AsyncMock(
            side_effect=discord.NotFound(MagicMock(status=404), "gone")
        )
        client.get_channel.return_value = channel

        await adapter.send_reply("123", "hello", "999")

        channel.send.assert_awaited_once_with("hello", tts=False)


# ── properties ─────────────────────────────────────────────────────────


class TestProperties:
    def test_platform_name_is_discord(self):
        adapter, _ = _bare_adapter()
        assert adapter.platform_name == "discord"

    def test_chunk_flush_at_is_under_message_limit(self):
        adapter, _ = _bare_adapter()
        assert adapter.chunk_flush_at < adapter.max_message_length
@@ -0,0 +1,134 @@
"""Discord slash command handlers.

Registered on the bot's CommandTree at startup. All responses are ephemeral
(visible only to the invoking user) to keep channels clean and to keep link
URLs private.
"""

import logging

import discord
from discord import app_commands

from backend.copilot.bot.bot_backend import BotBackend
from backend.util.exceptions import LinkAlreadyExistsError
from backend.util.settings import Settings

logger = logging.getLogger(__name__)


def register(tree: app_commands.CommandTree, api: BotBackend) -> None:
    """Register all slash commands on the given CommandTree."""

    @tree.command(
        name="setup",
        description="Link this server to an AutoGPT account for AutoPilot",
    )
    async def setup_command(interaction: discord.Interaction) -> None:
        await _handle_setup(interaction, api)

    @tree.command(name="help", description="Show AutoPilot bot usage info")
    async def help_command(interaction: discord.Interaction) -> None:
        await _handle_help(interaction)

    @tree.command(
        name="unlink",
        description="Manage linked servers from your AutoGPT settings",
    )
    async def unlink_command(interaction: discord.Interaction) -> None:
        await _handle_unlink(interaction)


async def _handle_setup(interaction: discord.Interaction, api: BotBackend) -> None:
    if interaction.guild is None:
        await interaction.response.send_message(
            "This command can only be used in a server. "
            "To link your DMs, just send me a direct message.",
            ephemeral=True,
        )
        return

    await interaction.response.defer(ephemeral=True)
    try:
        result = await api.create_link_token(
            platform="discord",
            platform_server_id=str(interaction.guild.id),
            platform_user_id=str(interaction.user.id),
            platform_username=interaction.user.display_name,
            server_name=interaction.guild.name,
            channel_id=str(interaction.channel_id or ""),
        )
    except LinkAlreadyExistsError:
        await interaction.followup.send(
            "This server is already linked — just mention me!",
            ephemeral=True,
        )
        return
    except Exception:
        logger.exception("Failed to create link token")
        await interaction.followup.send(
            "Something went wrong. Try again later.",
            ephemeral=True,
        )
        return

    view = discord.ui.View()
    view.add_item(
        discord.ui.Button(
            style=discord.ButtonStyle.link,
            label="Link Server",
            url=result.link_url,
        )
    )
    await interaction.followup.send(
        f"**Set up AutoPilot for {interaction.guild.name}**\n\n"
        "Click the button below to connect this server to your AutoGPT "
        "account. Once confirmed, everyone here can mention me to use "
        "AutoPilot.\n\n"
        "All usage will be billed to your account.\n"
        "This link expires in 30 minutes.",
        ephemeral=True,
        view=view,
    )


async def _handle_help(interaction: discord.Interaction) -> None:
    await interaction.response.send_message(
        "**AutoPilot Bot**\n\n"
        "Mention me in a server or DM me directly to chat.\n\n"
        "**Commands:**\n"
        "- `/setup` — Link this server to your AutoGPT account\n"
        "- `/help` — Show this message\n"
        "- `/unlink` — Manage linked accounts\n\n"
        "**How it works:**\n"
        "- In a server: the person who runs `/setup` pays for usage\n"
        "- In DMs: you link and pay for your own usage\n",
        ephemeral=True,
    )


async def _handle_unlink(interaction: discord.Interaction) -> None:
    config = Settings().config
    base_url = config.frontend_base_url or config.platform_base_url
    message = (
        "Unlinking requires authentication, so it has to be done "
        "from the web. Click below to manage your linked accounts."
    )

    if not base_url:
        await interaction.response.send_message(
            f"{message}\n\nOpen your AutoGPT settings and visit "
            "Profile → Linked accounts.",
            ephemeral=True,
        )
        return

    view = discord.ui.View()
    view.add_item(
        discord.ui.Button(
            style=discord.ButtonStyle.link,
            label="Open Settings",
            url=f"{base_url}/profile/settings",
        )
    )
    await interaction.response.send_message(message, ephemeral=True, view=view)
@@ -0,0 +1,165 @@
```python
"""Tests for Discord slash command handlers.

Targets the ``_handle_*`` functions directly — sidesteps ``CommandTree``
registration since it requires a live ``discord.Client``.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.util.exceptions import LinkAlreadyExistsError

from ...bot_backend import LinkTokenResult
from .commands import _handle_help, _handle_setup, _handle_unlink


def _interaction(*, guild: bool = True) -> MagicMock:
    interaction = MagicMock()
    interaction.response.send_message = AsyncMock()
    interaction.response.defer = AsyncMock()
    interaction.followup.send = AsyncMock()
    if guild:
        # MagicMock treats `name` as a constructor kwarg for the mock's repr,
        # not as an attribute — so set it after construction.
        interaction.guild = MagicMock(id=123)
        interaction.guild.name = "Test Guild"
        interaction.user = MagicMock(id=456, display_name="Bently")
        interaction.channel_id = 789
    else:
        interaction.guild = None
        interaction.user = MagicMock(id=456, display_name="Bently")
        interaction.channel_id = None
    return interaction


def _api_with_token() -> MagicMock:
    api = MagicMock()
    api.create_link_token = AsyncMock(
        return_value=LinkTokenResult(
            token="abc",
            link_url="https://example.com/link/abc",
            expires_at="2099-01-01T00:00:00Z",
        )
    )
    return api


class TestHandleSetup:
    @pytest.mark.asyncio
    async def test_dm_invocation_rejects_early(self):
        interaction = _interaction(guild=False)
        api = _api_with_token()
        await _handle_setup(interaction, api)

        interaction.response.send_message.assert_awaited_once()
        api.create_link_token.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_guild_invocation_creates_token_and_posts_button(self):
        interaction = _interaction()
        api = _api_with_token()
        await _handle_setup(interaction, api)

        interaction.response.defer.assert_awaited_once_with(ephemeral=True)
        api.create_link_token.assert_awaited_once()
        call_kwargs = api.create_link_token.await_args.kwargs
        assert call_kwargs["platform"] == "discord"
        assert call_kwargs["platform_server_id"] == "123"
        assert call_kwargs["server_name"] == "Test Guild"

        interaction.followup.send.assert_awaited_once()
        sent = interaction.followup.send.await_args
        assert "Set up AutoPilot for Test Guild" in sent.args[0]
        assert sent.kwargs["view"] is not None

    @pytest.mark.asyncio
    async def test_already_linked_gets_friendly_message(self):
        interaction = _interaction()
        api = _api_with_token()
        api.create_link_token = AsyncMock(side_effect=LinkAlreadyExistsError("already"))

        await _handle_setup(interaction, api)

        interaction.followup.send.assert_awaited_once()
        msg = interaction.followup.send.await_args.args[0]
        assert "already linked" in msg

    @pytest.mark.asyncio
    async def test_backend_error_surfaces_generic_message(self):
        interaction = _interaction()
        api = _api_with_token()
        api.create_link_token = AsyncMock(side_effect=RuntimeError("boom"))

        await _handle_setup(interaction, api)

        interaction.followup.send.assert_awaited_once()
        msg = interaction.followup.send.await_args.args[0]
        assert "went wrong" in msg.lower()


class TestHandleHelp:
    @pytest.mark.asyncio
    async def test_help_sends_ephemeral_message(self):
        interaction = _interaction()
        await _handle_help(interaction)
        interaction.response.send_message.assert_awaited_once()
        assert interaction.response.send_message.await_args.kwargs["ephemeral"] is True
        body = interaction.response.send_message.await_args.args[0]
        assert "/setup" in body
        assert "/help" in body
        assert "/unlink" in body


class TestHandleUnlink:
    @pytest.mark.asyncio
    async def test_with_frontend_url_posts_button(self):
        interaction = _interaction()
        fake_settings = MagicMock()
        fake_settings.config.frontend_base_url = "http://localhost:3000"
        fake_settings.config.platform_base_url = ""
        with patch(
            "backend.copilot.bot.adapters.discord.commands.Settings",
            return_value=fake_settings,
        ):
            await _handle_unlink(interaction)

        interaction.response.send_message.assert_awaited_once()
        sent = interaction.response.send_message.await_args
        assert sent.kwargs["view"] is not None
        assert sent.kwargs["ephemeral"] is True

    @pytest.mark.asyncio
    async def test_falls_back_to_platform_base_url(self):
        interaction = _interaction()
        fake_settings = MagicMock()
        fake_settings.config.frontend_base_url = ""
        fake_settings.config.platform_base_url = "http://other"
        with patch(
            "backend.copilot.bot.adapters.discord.commands.Settings",
            return_value=fake_settings,
        ):
            await _handle_unlink(interaction)

        # Button uses the fallback URL.
        sent = interaction.response.send_message.await_args
        view = sent.kwargs["view"]
        assert any(
            "http://other" in getattr(child, "url", "") for child in view.children
        )

    @pytest.mark.asyncio
    async def test_no_urls_configured_sends_plain_text(self):
        interaction = _interaction()
        fake_settings = MagicMock()
        fake_settings.config.frontend_base_url = ""
        fake_settings.config.platform_base_url = ""
        with patch(
            "backend.copilot.bot.adapters.discord.commands.Settings",
            return_value=fake_settings,
        ):
            await _handle_unlink(interaction)

        sent = interaction.response.send_message.await_args
        assert "view" not in sent.kwargs or sent.kwargs.get("view") is None
        assert "Profile" in sent.args[0]
```
@@ -0,0 +1,15 @@
```python
"""Discord-specific configuration."""

from backend.util.settings import Settings


def get_bot_token() -> str:
    return Settings().secrets.autopilot_bot_discord_token


# Discord message content limit (hard platform cap)
MAX_MESSAGE_LENGTH = 2000

# Flush the streaming buffer at 1900 — leaves 100-char headroom under the
# 2000 cap so the boundary-splitter has room to reach a natural break point.
CHUNK_FLUSH_AT = 1900
```
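The splitter that consumes this headroom lives in `.text` (imported by `handler.py` as `split_at_boundary`) and is not part of this diff hunk. As a hypothetical sketch of how such a splitter could use the 100-char headroom, assuming it prefers paragraph breaks, then newlines, then spaces, and only hard-splits when no boundary exists:

```python
def split_at_boundary(buffer: str, flush_at: int) -> tuple[str, str]:
    """Split `buffer` into (post_now, carry_forward).

    The first part ends at a natural boundary at or before `flush_at`
    characters; the remainder stays in the streaming buffer. Sketch only;
    the real implementation in `.text` may differ.
    """
    if len(buffer) <= flush_at:
        return buffer, ""
    window = buffer[:flush_at]
    # Prefer a paragraph break, then a line break, then a word break.
    for sep in ("\n\n", "\n", " "):
        idx = window.rfind(sep)
        if idx > 0:
            return buffer[:idx], buffer[idx + len(sep):]
    # No boundary in the window: hard split at the limit.
    return buffer[:flush_at], buffer[flush_at:]
```

With `flush_at=1900` this keeps every posted chunk under Discord's 2000-char cap while avoiding splits mid-word.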
156
autogpt_platform/backend/backend/copilot/bot/app.py
Normal file
@@ -0,0 +1,156 @@
```python
"""CoPilot Chat Bridge — AppService that runs the configured chat-platform
adapters (Discord, Telegram, Slack) and exposes outbound message RPC for
other services to push messages into chat platforms.
"""

import asyncio
import logging
from concurrent.futures import Future

from backend.platform_linking.models import Platform
from backend.util.service import (
    AppService,
    AppServiceClient,
    UnhealthyServiceError,
    endpoint_to_async,
    expose,
)
from backend.util.settings import Settings

from .adapters.base import PlatformAdapter
from .adapters.discord import config as discord_config
from .adapters.discord.adapter import DiscordAdapter
from .bot_backend import BotBackend
from .handler import MessageHandler

logger = logging.getLogger(__name__)

# Stay up for health-checks and runtime reconfiguration when no adapter is
# configured (e.g. deployed without a Discord token).
_NO_ADAPTER_SLEEP_SECONDS = 3600


class CoPilotChatBridge(AppService):
    """Bridges AutoPilot to external chat platforms via per-platform adapters."""

    def __init__(self):
        super().__init__()
        # Flipped to True once `_run_adapters` reaches its blocking gather
        # (or the no-adapter idle loop), and back to False if the task exits
        # for any reason. Consumed by `health_check` so orchestrators can
        # restart the pod when the bridge is dead-but-listening.
        self._adapters_healthy = False

    @classmethod
    def get_port(cls) -> int:
        return Settings().config.copilot_chat_bridge_port

    def run_service(self) -> None:
        future = asyncio.run_coroutine_threadsafe(
            self._run_adapters(), self.shared_event_loop
        )
        future.add_done_callback(self._on_adapters_exit)
        super().run_service()

    async def _run_adapters(self) -> None:
        api = BotBackend()
        adapters = _build_adapters(api)

        if not adapters:
            logger.info(
                "CoPilotChatBridge: no platform adapters configured — idling. "
                "Set AUTOPILOT_BOT_DISCORD_TOKEN (or another platform token) to "
                "enable an adapter."
            )
            self._adapters_healthy = True
            try:
                while True:
                    await asyncio.sleep(_NO_ADAPTER_SLEEP_SECONDS)
            finally:
                await api.close()

        handler = MessageHandler(api)
        for adapter in adapters:
            adapter.on_message(handler.handle)

        self._adapters_healthy = True
        try:
            await asyncio.gather(*(a.start() for a in adapters))
        finally:
            await asyncio.gather(*(a.stop() for a in adapters), return_exceptions=True)
            await api.close()

    def _on_adapters_exit(self, future: "Future[None]") -> None:
        """Surface exceptions from `_run_adapters` and flip the health flag.

        `run_coroutine_threadsafe` would otherwise swallow the exception
        into the returned future, leaving the FastAPI health endpoint
        cheerfully reporting OK on a dead bridge.
        """
        self._adapters_healthy = False
        exc = future.exception()
        if exc is not None:
            logger.error("CoPilotChatBridge adapters crashed: %r", exc, exc_info=exc)
        else:
            logger.warning("CoPilotChatBridge adapters exited without error")

    async def health_check(self) -> str:
        if not self._adapters_healthy:
            raise UnhealthyServiceError("CoPilotChatBridge adapter task is not running")
        return await super().health_check()

    @expose
    async def send_message_to_channel(
        self,
        platform: Platform,
        channel_id: str,
        content: str,
    ) -> bool:
        """Deliver a message to a channel on the given platform.

        Stub — scaffolding for the inbound-RPC pattern (backend → chat
        platform). Not yet wired to a concrete adapter. Callers must not use
        ``request_retry=True`` on the client until this is implemented, since
        ``ValueError`` crosses the RPC boundary as a client-side 4xx-ish error
        rather than a transient 5xx.
        """
        raise ValueError(f"send_message_to_channel not yet wired for {platform.value}")

    @expose
    async def send_dm(
        self,
        platform: Platform,
        platform_user_id: str,
        content: str,
    ) -> bool:
        """Deliver a DM to a user on the given platform.

        Stub — scaffolding for the inbound-RPC pattern. See
        :meth:`send_message_to_channel` for the retry caveat.
        """
        raise ValueError(f"send_dm not yet wired for {platform.value}")


class CoPilotChatBridgeClient(AppServiceClient):
    @classmethod
    def get_service_type(cls):
        return CoPilotChatBridge

    send_message_to_channel = endpoint_to_async(
        CoPilotChatBridge.send_message_to_channel
    )
    send_dm = endpoint_to_async(CoPilotChatBridge.send_dm)


def _build_adapters(api: BotBackend) -> list[PlatformAdapter]:
    """Instantiate adapters based on which platform tokens are configured."""
    adapters: list[PlatformAdapter] = []
    if discord_config.get_bot_token():
        adapters.append(DiscordAdapter(api))
        logger.info("Discord adapter enabled")
    # Future:
    # if telegram_config.get_bot_token():
    #     adapters.append(TelegramAdapter(api))
    # if slack_config.get_bot_token():
    #     adapters.append(SlackAdapter(api))
    return adapters
```
194
autogpt_platform/backend/backend/copilot/bot/bot_backend.py
Normal file
@@ -0,0 +1,194 @@
```python
"""Bot-side facade over `PlatformLinkingManagerClient` + `stream_registry`.

The `BotBackend` class is the bot's single entry point into the AutoGPT
backend — it wraps the linking RPC client and the copilot stream registry
behind plain string-typed methods. Adapters import this directly so the
discord/telegram/slack code never touches Pyro / Redis Streams plumbing.
"""

import asyncio
import logging
from dataclasses import dataclass
from typing import AsyncGenerator, Awaitable, Callable, Optional

from backend.copilot import stream_registry
from backend.copilot.response_model import StreamError, StreamFinish, StreamTextDelta
from backend.platform_linking.models import (
    BotChatRequest,
    CreateLinkTokenRequest,
    CreateUserLinkTokenRequest,
    Platform,
)
from backend.util.clients import get_platform_linking_manager_client
from backend.util.exceptions import (
    DuplicateChatMessageError,
    LinkAlreadyExistsError,
    NotFoundError,
)

# How long to wait for a single chunk from the copilot stream before giving
# up. Covers the case where the backend crashes mid-stream and never sends
# ``StreamFinish`` — without this, the bot would hang forever on ``queue.get()``.
STREAM_CHUNK_TIMEOUT_SECONDS = 120

logger = logging.getLogger(__name__)

__all__ = [
    "BotBackend",
    "DuplicateChatMessageError",
    "LinkAlreadyExistsError",
    "LinkTokenResult",
    "NotFoundError",
    "ResolveResult",
]


@dataclass
class ResolveResult:
    linked: bool


@dataclass
class LinkTokenResult:
    token: str
    link_url: str
    expires_at: str


class BotBackend:
    """Bot-side linking + chat operations, routed over cluster-internal RPC."""

    def __init__(self):
        self._client = get_platform_linking_manager_client()

    async def close(self) -> None:
        # The client's lifecycle is owned by the thread-cached factory; nothing
        # to close here. Kept for API compatibility with older bot code.
        pass

    async def resolve_server(
        self, platform: str, platform_server_id: str
    ) -> ResolveResult:
        resp = await self._client.resolve_server_link(
            platform=Platform(platform.upper()),
            platform_server_id=platform_server_id,
        )
        return ResolveResult(linked=resp.linked)

    async def resolve_user(self, platform: str, platform_user_id: str) -> ResolveResult:
        resp = await self._client.resolve_user_link(
            platform=Platform(platform.upper()),
            platform_user_id=platform_user_id,
        )
        return ResolveResult(linked=resp.linked)

    async def create_link_token(
        self,
        platform: str,
        platform_server_id: str,
        platform_user_id: str,
        platform_username: str,
        server_name: str,
        channel_id: str = "",
    ) -> LinkTokenResult:
        resp = await self._client.create_server_link_token(
            request=CreateLinkTokenRequest(
                platform=Platform(platform.upper()),
                platform_server_id=platform_server_id,
                platform_user_id=platform_user_id,
                platform_username=platform_username or None,
                server_name=server_name or None,
                channel_id=channel_id or None,
            )
        )
        return LinkTokenResult(
            token=resp.token,
            link_url=resp.link_url,
            expires_at=resp.expires_at.isoformat(),
        )

    async def create_user_link_token(
        self,
        platform: str,
        platform_user_id: str,
        platform_username: str,
    ) -> LinkTokenResult:
        resp = await self._client.create_user_link_token(
            request=CreateUserLinkTokenRequest(
                platform=Platform(platform.upper()),
                platform_user_id=platform_user_id,
                platform_username=platform_username or None,
            )
        )
        return LinkTokenResult(
            token=resp.token,
            link_url=resp.link_url,
            expires_at=resp.expires_at.isoformat(),
        )

    async def stream_chat(
        self,
        platform: str,
        platform_user_id: str,
        message: str,
        session_id: Optional[str] = None,
        platform_server_id: Optional[str] = None,
        on_session_id: Optional[Callable[[str], Awaitable[None]]] = None,
    ) -> AsyncGenerator[str, None]:
        """Start a copilot turn and yield text deltas from the stream.

        Raises :class:`DuplicateChatMessageError` if the same message is
        already in flight for this session.
        """
        handle = await self._client.start_chat_turn(
            request=BotChatRequest(
                platform=Platform(platform.upper()),
                platform_user_id=platform_user_id,
                message=message,
                session_id=session_id,
                platform_server_id=platform_server_id,
            )
        )
        if on_session_id:
            await on_session_id(handle.session_id)

        queue = await stream_registry.subscribe_to_session(
            session_id=handle.session_id,
            user_id=handle.user_id,
            last_message_id=handle.subscribe_from,
        )
        if queue is None:
            yield "\n[Error: failed to subscribe to response stream]"
            return

        try:
            while True:
                try:
                    chunk = await asyncio.wait_for(
                        queue.get(), timeout=STREAM_CHUNK_TIMEOUT_SECONDS
                    )
                except asyncio.TimeoutError:
                    logger.warning(
                        "Stream idle timeout after %ss for session %s",
                        STREAM_CHUNK_TIMEOUT_SECONDS,
                        handle.session_id,
                    )
                    yield "\n[Error: response timed out]"
                    return
                if isinstance(chunk, StreamTextDelta):
                    if chunk.delta:
                        yield chunk.delta
                elif isinstance(chunk, StreamFinish):
                    return
                elif isinstance(chunk, StreamError):
                    logger.error("Stream error from backend: %s", chunk.errorText)
                    yield f"\n[Error: {chunk.errorText}]"
                    return
                # Other StreamX types (StreamStart, StreamTextStart, tool events,
                # etc.) are emitted by the executor for the frontend UI and
                # aren't useful for the plain-text bot transcript.
        finally:
            await stream_registry.unsubscribe_from_session(
                session_id=handle.session_id,
                subscriber_queue=queue,
            )
```
216
autogpt_platform/backend/backend/copilot/bot/bot_backend_test.py
Normal file
@@ -0,0 +1,216 @@
```python
"""Tests for the bot's thin facade over PlatformLinkingManagerClient."""

import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.copilot.response_model import StreamError, StreamFinish, StreamTextDelta
from backend.platform_linking.models import (
    ChatTurnHandle,
    LinkTokenResponse,
    ResolveResponse,
)
from backend.util.exceptions import (
    DuplicateChatMessageError,
    LinkAlreadyExistsError,
    NotFoundError,
)

from .bot_backend import BotBackend


@pytest.fixture
def api() -> BotBackend:
    with patch("backend.copilot.bot.bot_backend.get_platform_linking_manager_client"):
        instance = BotBackend()
        # Swap in a MagicMock whose RPC methods are AsyncMocks — simpler than
        # patching each call site.
        instance._client = MagicMock()
    return instance


class TestResolve:
    @pytest.mark.asyncio
    async def test_resolve_server(self, api: BotBackend):
        api._client.resolve_server_link = AsyncMock(
            return_value=ResolveResponse(linked=True)
        )
        result = await api.resolve_server("discord", "g1")
        assert result.linked is True
        api._client.resolve_server_link.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_resolve_user(self, api: BotBackend):
        api._client.resolve_user_link = AsyncMock(
            return_value=ResolveResponse(linked=False)
        )
        result = await api.resolve_user("discord", "u1")
        assert result.linked is False


class TestCreateLinkTokens:
    @pytest.mark.asyncio
    async def test_create_server_link_token(self, api: BotBackend):
        api._client.create_server_link_token = AsyncMock(
            return_value=LinkTokenResponse(
                token="abc",
                expires_at=datetime.now(timezone.utc),
                link_url="https://example.com/link/abc",
            )
        )
        result = await api.create_link_token(
            platform="discord",
            platform_server_id="g1",
            platform_user_id="u1",
            platform_username="Bently",
            server_name="Test",
        )
        assert result.token == "abc"
        assert result.link_url.endswith("/link/abc")

    @pytest.mark.asyncio
    async def test_create_server_link_token_propagates_already_exists(
        self, api: BotBackend
    ):
        api._client.create_server_link_token = AsyncMock(
            side_effect=LinkAlreadyExistsError("already linked")
        )
        with pytest.raises(LinkAlreadyExistsError):
            await api.create_link_token(
                platform="discord",
                platform_server_id="g1",
                platform_user_id="u1",
                platform_username="",
                server_name="",
            )

    @pytest.mark.asyncio
    async def test_create_user_link_token(self, api: BotBackend):
        api._client.create_user_link_token = AsyncMock(
            return_value=LinkTokenResponse(
                token="xyz",
                expires_at=datetime.now(timezone.utc),
                link_url="https://example.com/link/xyz",
            )
        )
        result = await api.create_user_link_token(
            platform="discord", platform_user_id="u1", platform_username="Bently"
        )
        assert result.token == "xyz"


class TestStreamChat:
    @pytest.mark.asyncio
    async def test_yields_text_deltas_and_terminates_on_finish(self, api: BotBackend):
        handle = ChatTurnHandle(session_id="sess", turn_id="turn", user_id="u1")
        api._client.start_chat_turn = AsyncMock(return_value=handle)

        queue: asyncio.Queue = asyncio.Queue()
        await queue.put(StreamTextDelta(id="1", delta="Hello "))
        await queue.put(StreamTextDelta(id="2", delta="world"))
        await queue.put(StreamFinish())

        captured_session_ids: list[str] = []

        async def capture(sid: str) -> None:
            captured_session_ids.append(sid)

        with (
            patch(
                "backend.copilot.bot.bot_backend.stream_registry.subscribe_to_session",
                new=AsyncMock(return_value=queue),
            ),
            patch(
                "backend.copilot.bot.bot_backend.stream_registry.unsubscribe_from_session",
                new=AsyncMock(),
            ),
        ):
            chunks: list[str] = []
            async for chunk in api.stream_chat(
                platform="discord",
                platform_user_id="u1",
                message="hi",
                on_session_id=capture,
            ):
                chunks.append(chunk)

        assert "".join(chunks) == "Hello world"
        assert captured_session_ids == ["sess"]

    @pytest.mark.asyncio
    async def test_surfaces_stream_error(self, api: BotBackend):
        handle = ChatTurnHandle(session_id="sess", turn_id="turn", user_id="u1")
        api._client.start_chat_turn = AsyncMock(return_value=handle)

        queue: asyncio.Queue = asyncio.Queue()
        await queue.put(StreamError(errorText="executor crashed"))

        with (
            patch(
                "backend.copilot.bot.bot_backend.stream_registry.subscribe_to_session",
                new=AsyncMock(return_value=queue),
            ),
            patch(
                "backend.copilot.bot.bot_backend.stream_registry.unsubscribe_from_session",
                new=AsyncMock(),
            ),
        ):
            chunks: list[str] = []
            async for chunk in api.stream_chat(
                platform="discord", platform_user_id="u1", message="hi"
            ):
                chunks.append(chunk)

        assert any("executor crashed" in c for c in chunks)

    @pytest.mark.asyncio
    async def test_duplicate_message_propagates(self, api: BotBackend):
        api._client.start_chat_turn = AsyncMock(
            side_effect=DuplicateChatMessageError("in flight")
        )

        with pytest.raises(DuplicateChatMessageError):
            async for _ in api.stream_chat(
                platform="discord", platform_user_id="u1", message="hi"
            ):
                pass

    @pytest.mark.asyncio
    async def test_session_not_found_propagates(self, api: BotBackend):
        api._client.start_chat_turn = AsyncMock(
            side_effect=NotFoundError("session gone")
        )

        with pytest.raises(NotFoundError):
            async for _ in api.stream_chat(
                platform="discord",
                platform_user_id="u1",
                message="hi",
                session_id="missing",
            ):
                pass

    @pytest.mark.asyncio
    async def test_subscribe_returns_none_yields_error(self, api: BotBackend):
        handle = ChatTurnHandle(session_id="sess", turn_id="turn", user_id="u1")
        api._client.start_chat_turn = AsyncMock(return_value=handle)

        with (
            patch(
                "backend.copilot.bot.bot_backend.stream_registry.subscribe_to_session",
                new=AsyncMock(return_value=None),
            ),
            patch(
                "backend.copilot.bot.bot_backend.stream_registry.unsubscribe_from_session",
                new=AsyncMock(),
            ),
        ):
            chunks: list[str] = []
            async for chunk in api.stream_chat(
                platform="discord", platform_user_id="u1", message="hi"
            ):
                chunks.append(chunk)

        assert any("failed to subscribe" in c.lower() for c in chunks)
```
4
autogpt_platform/backend/backend/copilot/bot/config.py
Normal file
@@ -0,0 +1,4 @@
```python
"""Platform-agnostic bot config."""

# Cache TTL for AutoPilot session IDs (per channel/thread)
SESSION_TTL = 86400  # 24 hours
```
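`handler.py` drains queued messages through `format_batch` from `.text`, which is not included in this diff. A hypothetical sketch of what such a batch formatter could look like, assuming a single message passes through untouched while multiple queued messages get speaker prefixes so the model can tell who said what:

```python
def format_batch(batch: list[tuple[str, str, str]], platform: str) -> str:
    """Collapse queued (username, user_id, text) entries into one chat turn.

    Sketch only; the real implementation in `.text` may differ (e.g. it
    could use platform-specific mention syntax via `platform`).
    """
    if len(batch) == 1:
        # Single message: no prefix needed.
        return batch[0][2]
    # Multiple messages arrived while a stream was running: prefix each
    # with its speaker so the batched follow-up turn stays attributable.
    return "\n".join(f"{username}: {text}" for username, _user_id, text in batch)
```

This matches the shape `TargetState.pending` accumulates in `handler.py`: each entry is `(username, user_id, text)`.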
280
autogpt_platform/backend/backend/copilot/bot/handler.py
Normal file
@@ -0,0 +1,280 @@
```python
"""Platform-agnostic message handler.

Receives a MessageContext from any adapter and drives the full AutoPilot
interaction: link resolution, thread routing, batched streaming with a
persistent typing indicator.
"""

import asyncio
import logging
from dataclasses import dataclass, field

from backend.data.redis_client import get_redis_async
from backend.util.exceptions import (
    DuplicateChatMessageError,
    LinkAlreadyExistsError,
    NotFoundError,
)

from . import threads
from .adapters.base import MessageContext, PlatformAdapter
from .bot_backend import BotBackend
from .config import SESSION_TTL
from .text import format_batch, split_at_boundary

logger = logging.getLogger(__name__)


@dataclass
class TargetState:
    """Per-target streaming state.

    A "target" is wherever the bot replies — a thread ID, a DM channel ID.
    `pending` holds messages that arrived while a stream was running; they
    get drained as a single batched follow-up turn when the stream ends.
    """

    processing: bool = False
    pending: list[tuple[str, str, str]] = field(default_factory=list)
    # Each entry: (username, user_id, text)


class MessageHandler:
    def __init__(self, api: BotBackend):
        self._api = api
        self._targets: dict[str, TargetState] = {}

    async def handle(self, ctx: MessageContext, adapter: PlatformAdapter) -> None:
        if not ctx.text.strip():
            if ctx.channel_type == "channel":
                await adapter.send_reply(
                    ctx.channel_id,
                    "You mentioned me but didn't say anything. How can I help?",
                    ctx.message_id,
                )
            return

        if not await self._ensure_linked(ctx, adapter):
            return

        target_id = await self._resolve_target(ctx, adapter)
        if not target_id:
            return  # Thread not subscribed, ignore silently

        await self._enqueue_and_process(ctx, adapter, target_id)

    # -- Target resolution --

    async def _resolve_target(
        self, ctx: MessageContext, adapter: PlatformAdapter
    ) -> str | None:
        if ctx.channel_type == "dm":
            return ctx.channel_id

        if ctx.channel_type == "thread":
            if await threads.is_subscribed(ctx.platform, ctx.channel_id):
                return ctx.channel_id
            return None

        # channel_type == "channel" — create a thread and subscribe
        thread_name = f"{ctx.username} × AutoPilot"
        thread_id = await adapter.create_thread(
            ctx.channel_id, ctx.message_id, thread_name
        )
        if not thread_id:
            logger.warning("Thread creation failed, falling back to channel reply")
            return ctx.channel_id
        await threads.subscribe(ctx.platform, thread_id)
        return thread_id

    # -- Batched streaming --

    async def _enqueue_and_process(
        self, ctx: MessageContext, adapter: PlatformAdapter, target_id: str
    ) -> None:
        state = self._targets.setdefault(target_id, TargetState())
        state.pending.append((ctx.username, ctx.user_id, ctx.text))

        if state.processing:
            # Another invocation is streaming for this target — it will pick
            # up the message we just appended when its current stream ends.
            return

        state.processing = True
        try:
            while state.pending:
                batch = list(state.pending)
                state.pending.clear()
                await self._stream_batch(batch, ctx, adapter, target_id)
        finally:
            state.processing = False
            # Drop the empty state so the dict doesn't grow unbounded across
            # the bot's lifetime.
            if not state.pending:
                self._targets.pop(target_id, None)

    async def _stream_batch(
        self,
        batch: list[tuple[str, str, str]],
        ctx: MessageContext,
        adapter: PlatformAdapter,
        target_id: str,
    ) -> None:
        prefixed = format_batch(batch, ctx.platform)

        redis = await get_redis_async()
        cache_key = f"copilot-bot:session:{ctx.platform}:{target_id}"
        cached_session_id = await redis.get(cache_key)

        async def _on_session_id(sid: str) -> None:
            try:
                await redis.set(cache_key, sid, ex=SESSION_TTL)
            except Exception:
                logger.warning("Failed to cache session id for target %s", target_id)

        flush_at = adapter.chunk_flush_at
        buffer = ""
        sent_any_content = False

        typing_task = asyncio.create_task(_keep_typing(adapter, target_id))
        try:
            async for chunk in self._api.stream_chat(
                platform=ctx.platform,
                platform_user_id=ctx.user_id,
                message=prefixed,
                session_id=cached_session_id,
                platform_server_id=ctx.server_id,
                on_session_id=_on_session_id,
            ):
                buffer += chunk
                if len(buffer) >= flush_at:
                    post, buffer = split_at_boundary(buffer, flush_at)
                    if post:
                        await adapter.send_message(target_id, post)
                        if post.strip():
                            sent_any_content = True
        except DuplicateChatMessageError:
            # Another in-flight turn is already processing this exact message —
            # stay quiet so the user doesn't get a double response.
            logger.info("Duplicate message dropped for target %s", target_id)
            return
        except NotFoundError:
            logger.exception("Chat turn rejected")
            await adapter.send_message(
                target_id, "AutoPilot ran into an error. Try again later."
            )
            return
        except Exception:
            logger.exception(
                "Unexpected error during streaming for target %s", target_id
            )
            await adapter.send_message(
                target_id,
                "Something went wrong. Try again in a moment.",
            )
            return
        finally:
            typing_task.cancel()
            try:
                await typing_task
            except asyncio.CancelledError:
                pass
            await adapter.stop_typing(target_id)

        if buffer.strip():
            await adapter.send_message(target_id, buffer)
            sent_any_content = True

        if not sent_any_content:
            await adapter.send_message(
                target_id,
                "AutoPilot didn't produce a response. Try rephrasing your question.",
            )

    # -- Linking --

    async def _ensure_linked(
        self, ctx: MessageContext, adapter: PlatformAdapter
```
|
||||
) -> bool:
|
||||
try:
|
||||
if ctx.is_dm:
|
||||
result = await self._api.resolve_user(ctx.platform, ctx.user_id)
|
||||
if not result.linked:
|
||||
await self._prompt_user_link(ctx, adapter)
|
||||
return False
|
||||
else:
|
||||
if not ctx.server_id:
|
||||
logger.error("Non-DM message missing server_id: %r", ctx)
|
||||
return False
|
||||
result = await self._api.resolve_server(ctx.platform, ctx.server_id)
|
||||
if not result.linked:
|
||||
await adapter.send_message(
|
||||
ctx.channel_id,
|
||||
"This server isn't linked to an AutoGPT account yet. "
|
||||
"Ask a server admin to run `/setup` first.",
|
||||
)
|
||||
return False
|
||||
except ValueError:
|
||||
# ValueError-based domain exceptions (NotFoundError etc.) arrive
|
||||
# over RPC with this base type.
|
||||
logger.exception("Failed to check link status")
|
||||
await adapter.send_message(
|
||||
ctx.channel_id, "Something went wrong. Try again later."
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
logger.exception("Unexpected error while checking link status")
|
||||
await adapter.send_message(
|
||||
ctx.channel_id,
|
||||
"Something went wrong. Try again in a moment.",
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _prompt_user_link(
|
||||
self, ctx: MessageContext, adapter: PlatformAdapter
|
||||
) -> None:
|
||||
try:
|
||||
result = await self._api.create_user_link_token(
|
||||
platform=ctx.platform,
|
||||
platform_user_id=ctx.user_id,
|
||||
platform_username=ctx.username,
|
||||
)
|
||||
platform_display = ctx.platform.capitalize()
|
||||
await adapter.send_link(
|
||||
ctx.channel_id,
|
||||
f"Your {platform_display} DMs aren't linked to an AutoGPT "
|
||||
"account yet. Click below to connect — once linked, you can "
|
||||
"chat with AutoPilot right here.",
|
||||
link_label="Link Account",
|
||||
link_url=result.link_url,
|
||||
)
|
||||
except LinkAlreadyExistsError:
|
||||
# Race: user got linked between resolve_user and create. Re-check
|
||||
# — if still not linked, the backend returned a stale error and
|
||||
# we shouldn't spam the user.
|
||||
re_check = await self._api.resolve_user(ctx.platform, ctx.user_id)
|
||||
if re_check.linked:
|
||||
return
|
||||
logger.exception(
|
||||
"create_user_link_token raised 'already exists' "
|
||||
"but user isn't actually linked"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to create user link token")
|
||||
await adapter.send_message(
|
||||
ctx.channel_id,
|
||||
"Something went wrong setting up the link. Try again later.",
|
||||
)
|
||||
|
||||
|
||||
async def _keep_typing(adapter: PlatformAdapter, target_id: str) -> None:
|
||||
"""Re-fire the typing indicator every 8s so it doesn't expire mid-stream."""
|
||||
try:
|
||||
while True:
|
||||
await adapter.start_typing(target_id)
|
||||
await asyncio.sleep(8)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.debug("Typing loop error", exc_info=True)
|
||||
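The `processing` flag plus `pending` list in `_enqueue_and_process` is a per-target message coalescer: the first caller drains the queue in a loop, and later callers just append and return. A minimal standalone sketch of the same pattern (names here are illustrative, not the repo's):

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class State:
    processing: bool = False
    pending: list[str] = field(default_factory=list)


targets: dict[str, State] = {}
batches: list[list[str]] = []  # what each drain iteration saw


async def handle(target: str, msg: str) -> None:
    state = targets.setdefault(target, State())
    state.pending.append(msg)
    if state.processing:
        return  # the active drainer will pick this up
    state.processing = True
    try:
        while state.pending:
            batch, state.pending = state.pending, []
            batches.append(batch)
            await asyncio.sleep(0)  # stand-in for streaming a response
    finally:
        state.processing = False
        if not state.pending:
            targets.pop(target, None)


async def main() -> None:
    # Two messages racing for the same target: the second one coalesces
    # into the first caller's drain loop instead of starting a new stream.
    await asyncio.gather(handle("t1", "first"), handle("t1", "second"))


asyncio.run(main())
print(batches)
```

The drain loop guarantees messages appended mid-stream are never lost, which is exactly what `test_drain_loop_picks_up_appended_messages` below exercises.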
338 autogpt_platform/backend/backend/copilot/bot/handler_test.py Normal file
@@ -0,0 +1,338 @@
"""Tests for the platform-agnostic message handler."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.exceptions import DuplicateChatMessageError, NotFoundError
|
||||
|
||||
from .adapters.base import ChannelType, MessageContext
|
||||
from .bot_backend import LinkTokenResult, ResolveResult
|
||||
from .handler import MessageHandler, TargetState
|
||||
|
||||
|
||||
def _ctx(
|
||||
*,
|
||||
channel_type: ChannelType = "channel",
|
||||
server_id: str | None = "guild-1",
|
||||
channel_id: str = "chan-1",
|
||||
message_id: str = "msg-1",
|
||||
user_id: str = "user-1",
|
||||
username: str = "Bently",
|
||||
text: str = "hello bot",
|
||||
) -> MessageContext:
|
||||
return MessageContext(
|
||||
platform="discord",
|
||||
channel_type=channel_type,
|
||||
server_id=server_id,
|
||||
channel_id=channel_id,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
text=text,
|
||||
)
|
||||
|
||||
|
||||
def _adapter() -> MagicMock:
|
||||
adapter = MagicMock()
|
||||
adapter.chunk_flush_at = 1900
|
||||
adapter.send_message = AsyncMock()
|
||||
adapter.send_reply = AsyncMock()
|
||||
adapter.send_link = AsyncMock()
|
||||
adapter.start_typing = AsyncMock()
|
||||
adapter.stop_typing = AsyncMock()
|
||||
adapter.create_thread = AsyncMock(return_value="thread-new")
|
||||
return adapter
|
||||
|
||||
|
||||
def _api(*, server_linked: bool = True, user_linked: bool = True) -> MagicMock:
|
||||
api = MagicMock()
|
||||
api.resolve_server = AsyncMock(return_value=ResolveResult(linked=server_linked))
|
||||
api.resolve_user = AsyncMock(return_value=ResolveResult(linked=user_linked))
|
||||
api.create_user_link_token = AsyncMock(
|
||||
return_value=LinkTokenResult(
|
||||
token="t",
|
||||
link_url="https://example.com/link/t",
|
||||
expires_at="2099-01-01T00:00:00Z",
|
||||
)
|
||||
)
|
||||
|
||||
async def _empty_stream(*args, **kwargs):
|
||||
if False:
|
||||
yield ""
|
||||
|
||||
api.stream_chat = _empty_stream
|
||||
return api
|
||||
|
||||
|
||||
class TestEmptyMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_mention_without_text_gets_nudge(self):
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
await handler.handle(_ctx(text=" "), adapter)
|
||||
adapter.send_reply.assert_awaited_once()
|
||||
adapter.send_message.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_dm_is_silently_dropped(self):
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
await handler.handle(_ctx(channel_type="dm", text=""), adapter)
|
||||
adapter.send_reply.assert_not_awaited()
|
||||
adapter.send_message.assert_not_awaited()
|
||||
|
||||
|
||||
class TestEnsureLinked:
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlinked_server_tells_user_to_setup(self):
|
||||
handler = MessageHandler(_api(server_linked=False))
|
||||
adapter = _adapter()
|
||||
await handler.handle(_ctx(), adapter)
|
||||
call_args = adapter.send_message.await_args.args
|
||||
assert "isn't linked" in call_args[1]
|
||||
assert "/setup" in call_args[1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlinked_dm_prompts_link_flow(self):
|
||||
handler = MessageHandler(_api(user_linked=False))
|
||||
adapter = _adapter()
|
||||
await handler.handle(_ctx(channel_type="dm", server_id=None), adapter)
|
||||
adapter.send_link.assert_awaited_once()
|
||||
assert adapter.send_link.await_args.kwargs["link_url"].startswith(
|
||||
"https://example.com/link/"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_dm_without_server_id_is_rejected(self):
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
await handler.handle(_ctx(server_id=None), adapter)
|
||||
# Guard short-circuits before calling resolve_server.
|
||||
handler._api.resolve_server.assert_not_awaited()
|
||||
adapter.send_message.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backend_error_in_resolve_produces_message(self):
|
||||
api = _api()
|
||||
api.resolve_server = AsyncMock(side_effect=NotFoundError("boom"))
|
||||
handler = MessageHandler(api)
|
||||
adapter = _adapter()
|
||||
await handler.handle(_ctx(), adapter)
|
||||
adapter.send_message.assert_awaited_once()
|
||||
assert "went wrong" in adapter.send_message.await_args.args[1].lower()
|
||||
|
||||
|
||||
class TestResolveTarget:
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_reuses_channel_id(self):
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
ctx = _ctx(channel_type="dm", server_id=None, channel_id="dm-42")
|
||||
result = await handler._resolve_target(ctx, adapter)
|
||||
assert result == "dm-42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribed_thread_returns_none(self):
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
ctx = _ctx(channel_type="thread", channel_id="thread-old")
|
||||
with patch(
|
||||
"backend.copilot.bot.handler.threads.is_subscribed",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert await handler._resolve_target(ctx, adapter) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribed_thread_keeps_channel(self):
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
ctx = _ctx(channel_type="thread", channel_id="thread-ok")
|
||||
with patch(
|
||||
"backend.copilot.bot.handler.threads.is_subscribed",
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert await handler._resolve_target(ctx, adapter) == "thread-ok"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_creates_and_subscribes_thread(self):
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
adapter.create_thread = AsyncMock(return_value="thread-created")
|
||||
with patch(
|
||||
"backend.copilot.bot.handler.threads.subscribe", new=AsyncMock()
|
||||
) as subscribe:
|
||||
result = await handler._resolve_target(_ctx(), adapter)
|
||||
assert result == "thread-created"
|
||||
subscribe.assert_awaited_once_with("discord", "thread-created")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_falls_back_to_parent_when_thread_creation_fails(self):
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
adapter.create_thread = AsyncMock(return_value=None)
|
||||
result = await handler._resolve_target(_ctx(channel_id="parent-chan"), adapter)
|
||||
assert result == "parent-chan"
|
||||
|
||||
|
||||
class TestBatching:
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_message_queues_when_processing(self):
|
||||
"""Second caller with processing=True returns without starting a new stream."""
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
state = TargetState(processing=True)
|
||||
handler._targets["target-1"] = state
|
||||
|
||||
await handler._enqueue_and_process(_ctx(text="second"), adapter, "target-1")
|
||||
|
||||
assert state.processing is True
|
||||
assert state.pending == [("Bently", "user-1", "second")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_target_state_cleared_after_drain(self):
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
|
||||
stream_calls: list[list] = []
|
||||
|
||||
async def fake_stream_batch(batch, ctx, ad, tid):
|
||||
stream_calls.append(list(batch))
|
||||
|
||||
handler._stream_batch = fake_stream_batch # type: ignore[method-assign]
|
||||
|
||||
await handler._enqueue_and_process(_ctx(text="hello"), adapter, "target-1")
|
||||
assert stream_calls == [[("Bently", "user-1", "hello")]]
|
||||
# Dict entry should be gone once processing finishes with empty pending.
|
||||
assert "target-1" not in handler._targets
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_loop_picks_up_appended_messages(self):
|
||||
"""Messages appended to pending mid-drain are processed in the next iter."""
|
||||
handler = MessageHandler(_api())
|
||||
adapter = _adapter()
|
||||
|
||||
state = TargetState()
|
||||
handler._targets["target-1"] = state
|
||||
|
||||
seen: list[list] = []
|
||||
|
||||
async def fake_stream_batch(batch, ctx, ad, tid):
|
||||
seen.append(list(batch))
|
||||
if len(seen) == 1:
|
||||
# Simulate another caller appending during the first stream.
|
||||
state.pending.append(("Later", "u2", "follow-up"))
|
||||
|
||||
handler._stream_batch = fake_stream_batch # type: ignore[method-assign]
|
||||
await handler._enqueue_and_process(_ctx(text="first"), adapter, "target-1")
|
||||
|
||||
assert seen == [
|
||||
[("Bently", "user-1", "first")],
|
||||
[("Later", "u2", "follow-up")],
|
||||
]
|
||||
assert "target-1" not in handler._targets
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_message_is_silently_dropped(self):
|
||||
api = _api()
|
||||
|
||||
async def duplicate_stream(*args, **kwargs):
|
||||
raise DuplicateChatMessageError("in flight")
|
||||
yield "" # pragma: no cover
|
||||
|
||||
api.stream_chat = duplicate_stream
|
||||
handler = MessageHandler(api)
|
||||
adapter = _adapter()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.bot.handler.get_redis_async",
|
||||
new=AsyncMock(return_value=AsyncMock(get=AsyncMock(return_value=None))),
|
||||
):
|
||||
await handler._stream_batch(
|
||||
[("Bently", "u1", "hi")], _ctx(), adapter, "target-1"
|
||||
)
|
||||
|
||||
adapter.send_message.assert_not_awaited()
|
||||
|
||||
|
||||
class TestStreamFallback:
|
||||
"""Covers the empty-response fallback, including the boundary-flush bug
|
||||
where prior code posted 'AutoPilot didn't produce a response' even though
|
||||
content had already been flushed mid-stream.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _redis_patch():
|
||||
return patch(
|
||||
"backend.copilot.bot.handler.get_redis_async",
|
||||
new=AsyncMock(return_value=AsyncMock(get=AsyncMock(return_value=None))),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_stream_sends_fallback(self):
|
||||
api = _api()
|
||||
|
||||
async def empty(*args, **kwargs):
|
||||
if False:
|
||||
yield ""
|
||||
|
||||
api.stream_chat = empty
|
||||
handler = MessageHandler(api)
|
||||
adapter = _adapter()
|
||||
|
||||
with TestStreamFallback._redis_patch():
|
||||
await handler._stream_batch(
|
||||
[("Bently", "u1", "hi")], _ctx(), adapter, "target-1"
|
||||
)
|
||||
|
||||
msgs = [c.args[1] for c in adapter.send_message.await_args_list]
|
||||
assert any("didn't produce a response" in m for m in msgs)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_only_stream_sends_fallback(self):
|
||||
api = _api()
|
||||
|
||||
async def whitespace(*args, **kwargs):
|
||||
yield " "
|
||||
yield "\n\n"
|
||||
|
||||
api.stream_chat = whitespace
|
||||
handler = MessageHandler(api)
|
||||
adapter = _adapter()
|
||||
|
||||
with TestStreamFallback._redis_patch():
|
||||
await handler._stream_batch(
|
||||
[("Bently", "u1", "hi")], _ctx(), adapter, "target-1"
|
||||
)
|
||||
|
||||
msgs = [c.args[1] for c in adapter.send_message.await_args_list]
|
||||
assert any("didn't produce a response" in m for m in msgs)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_flushed_mid_stream_does_not_trigger_fallback(self):
|
||||
"""Regression: before the fix, a response that flushed exactly at a
|
||||
boundary left buffer == "" and the fallback fired after real content
|
||||
had already been posted.
|
||||
"""
|
||||
api = _api()
|
||||
adapter = _adapter()
|
||||
adapter.chunk_flush_at = 50
|
||||
|
||||
async def streaming_content(*args, **kwargs):
|
||||
# Exactly flush_at chars → split_at_boundary returns the whole
|
||||
# payload as the post and an empty remainder, so the stream ends
|
||||
# with buffer == "". That USED to fall into the `elif not buffer`
|
||||
# branch and send the "didn't produce a response" fallback.
|
||||
yield "x" * 50
|
||||
|
||||
api.stream_chat = streaming_content
|
||||
handler = MessageHandler(api)
|
||||
|
||||
with TestStreamFallback._redis_patch():
|
||||
await handler._stream_batch(
|
||||
[("Bently", "u1", "hi")], _ctx(), adapter, "target-1"
|
||||
)
|
||||
|
||||
msgs = [c.args[1] for c in adapter.send_message.await_args_list]
|
||||
assert not any("didn't produce a response" in m for m in msgs)
|
||||
assert msgs == ["x" * 50]
|
||||
80 autogpt_platform/backend/backend/copilot/bot/text.py Normal file
@@ -0,0 +1,80 @@
"""Text formatting helpers — message batching and chunk splitting."""
|
||||
|
||||
import re
|
||||
|
||||
# Matches a triple-backtick fence with an optional language tag. Used to tell
|
||||
# whether a cut falls inside an open Markdown code block.
|
||||
_CODE_FENCE = re.compile(r"```(\w*)")
|
||||
|
||||
|
||||
def format_batch(batch: list[tuple[str, str, str]], platform: str) -> str:
|
||||
"""Format one or more pending messages into a single prompt for AutoPilot.
|
||||
|
||||
Each batch entry is (username, user_id, text). When multiple messages are
|
||||
batched together (because they arrived while the bot was streaming a prior
|
||||
response), they're labelled individually so the LLM can address each.
|
||||
"""
|
||||
platform_display = platform.capitalize()
|
||||
if len(batch) == 1:
|
||||
username, user_id, text = batch[0]
|
||||
return (
|
||||
f"[Message sent by {username} ({platform_display} user ID: {user_id})]\n"
|
||||
f"{text}"
|
||||
)
|
||||
|
||||
lines = ["[Multiple messages — please address them together]"]
|
||||
for username, user_id, text in batch:
|
||||
lines.append(
|
||||
f"\n[From {username} ({platform_display} user ID: {user_id})]\n{text}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def split_at_boundary(text: str, flush_at: int) -> tuple[str, str]:
|
||||
"""Split text at a natural boundary to fit within a length limit.
|
||||
|
||||
Returns (postable_chunk, remaining_text).
|
||||
Prefers: paragraph > newline > sentence end > space > hard cut.
|
||||
If the cut lands inside a Markdown code fence (``\\`\\`\\``), the fence is
|
||||
closed in the chunk and reopened at the start of the remainder so both
|
||||
sides render correctly.
|
||||
"""
|
||||
if len(text) <= flush_at:
|
||||
return text, ""
|
||||
|
||||
search_start = max(0, flush_at - 200)
|
||||
search_region = text[search_start:flush_at]
|
||||
|
||||
for sep in ("\n\n", "\n"):
|
||||
idx = search_region.rfind(sep)
|
||||
if idx != -1:
|
||||
cut = search_start + idx
|
||||
return _balance_code_fences(text[:cut].rstrip(), text[cut:].lstrip("\n"))
|
||||
|
||||
for sep in (". ", "! ", "? "):
|
||||
idx = search_region.rfind(sep)
|
||||
if idx != -1:
|
||||
cut = search_start + idx + len(sep)
|
||||
return _balance_code_fences(text[:cut], text[cut:])
|
||||
|
||||
idx = search_region.rfind(" ")
|
||||
if idx != -1:
|
||||
cut = search_start + idx
|
||||
return _balance_code_fences(text[:cut], text[cut:].lstrip())
|
||||
|
||||
return _balance_code_fences(text[:flush_at], text[flush_at:])
|
||||
|
||||
|
||||
def _balance_code_fences(before: str, after: str) -> tuple[str, str]:
|
||||
"""If ``before`` ends inside an open ``\\`\\`\\`` fence, close and reopen it.
|
||||
|
||||
Preserves the language tag from the opening fence so syntax highlighting
|
||||
survives the split.
|
||||
"""
|
||||
fences = _CODE_FENCE.findall(before)
|
||||
if len(fences) % 2 == 0:
|
||||
return before, after
|
||||
lang = fences[-1]
|
||||
closed_before = f"{before.rstrip()}\n```"
|
||||
reopened_after = f"```{lang}\n{after.lstrip()}"
|
||||
return closed_before, reopened_after
|
||||
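To see why the odd-count check in `_balance_code_fences` works: each triple-backtick fence toggles in/out of a code block, so an odd number of fences in the head means the cut landed inside an open block. A minimal standalone re-sketch of that idea (not importing the repo module; names are illustrative):

```python
import re

_FENCE = re.compile(r"```(\w*)")  # fence with optional language tag


def balance(before: str, after: str) -> tuple[str, str]:
    fences = _FENCE.findall(before)
    if len(fences) % 2 == 0:  # even count: cut is outside any code block
        return before, after
    lang = fences[-1]  # reopen with the language of the unclosed fence
    return f"{before.rstrip()}\n```", f"```{lang}\n{after.lstrip()}"


# A cut that lands inside an open ```py block: the head gets a closing
# fence appended, the tail gets a matching reopening fence.
head, tail = balance("intro\n```py\nx = 1", "y = 2\n```\noutro")
print(head)
print("---")
print(tail)
```

Both halves now render as valid Markdown on their own, which is the property `TestSplitAtBoundaryWithCodeFences` below asserts.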
105 autogpt_platform/backend/backend/copilot/bot/text_test.py Normal file
@@ -0,0 +1,105 @@
"""Tests for message batching + boundary splitting."""
|
||||
|
||||
from .text import _balance_code_fences, format_batch, split_at_boundary
|
||||
|
||||
|
||||
class TestFormatBatch:
|
||||
def test_single_message_has_header(self):
|
||||
result = format_batch([("Bently", "123", "hello")], "discord")
|
||||
assert result == "[Message sent by Bently (Discord user ID: 123)]\nhello"
|
||||
|
||||
def test_multi_message_labels_each_sender(self):
|
||||
result = format_batch(
|
||||
[
|
||||
("Alice", "a1", "first"),
|
||||
("Bob", "b2", "second"),
|
||||
],
|
||||
"discord",
|
||||
)
|
||||
assert "[Multiple messages" in result
|
||||
assert "[From Alice (Discord user ID: a1)]\nfirst" in result
|
||||
assert "[From Bob (Discord user ID: b2)]\nsecond" in result
|
||||
|
||||
def test_platform_name_is_capitalized(self):
|
||||
result = format_batch([("u", "1", "x")], "telegram")
|
||||
assert "Telegram user ID" in result
|
||||
|
||||
|
||||
class TestSplitAtBoundary:
|
||||
def test_short_text_returns_unchanged(self):
|
||||
before, after = split_at_boundary("short", 100)
|
||||
assert before == "short"
|
||||
assert after == ""
|
||||
|
||||
def test_splits_at_paragraph_boundary(self):
|
||||
text = "first paragraph.\n\nsecond paragraph that is long enough"
|
||||
before, after = split_at_boundary(text, 20)
|
||||
assert before == "first paragraph."
|
||||
assert after == "second paragraph that is long enough"
|
||||
|
||||
def test_splits_at_newline_when_no_paragraph(self):
|
||||
text = "line one\nline two line three line four line five"
|
||||
before, after = split_at_boundary(text, 15)
|
||||
assert before == "line one"
|
||||
assert after == "line two line three line four line five"
|
||||
|
||||
def test_splits_at_sentence_when_no_newline(self):
|
||||
text = "First sentence. Second sentence is quite a bit longer here."
|
||||
before, after = split_at_boundary(text, 20)
|
||||
assert before == "First sentence. "
|
||||
assert after == "Second sentence is quite a bit longer here."
|
||||
|
||||
def test_falls_back_to_space_split(self):
|
||||
text = "word " * 50
|
||||
before, after = split_at_boundary(text, 30)
|
||||
assert not before.endswith(" ")
|
||||
# Rejoining drops one space at the cut, but no characters other
|
||||
# than whitespace should be lost.
|
||||
rejoined = (before + " " + after).replace(" ", " ").strip()
|
||||
assert rejoined == text.strip()
|
||||
|
||||
def test_hard_cut_on_single_long_token(self):
|
||||
text = "a" * 500
|
||||
before, after = split_at_boundary(text, 100)
|
||||
assert len(before) == 100
|
||||
assert after == "a" * 400
|
||||
|
||||
|
||||
class TestBalanceCodeFences:
|
||||
def test_balanced_code_unchanged(self):
|
||||
before = "prose\n```py\nprint('x')\n```\ntail"
|
||||
after = "more"
|
||||
b, a = _balance_code_fences(before, after)
|
||||
assert b == before
|
||||
assert a == after
|
||||
|
||||
def test_open_fence_gets_closed_and_reopened(self):
|
||||
before = "prose\n```py\nprint('x')"
|
||||
after = "print('y')\n```\ntail"
|
||||
b, a = _balance_code_fences(before, after)
|
||||
assert b.endswith("```")
|
||||
assert a.startswith("```py\n")
|
||||
|
||||
def test_reopens_with_no_lang_when_opener_had_none(self):
|
||||
before = "```\nsome code here"
|
||||
after = "more code\n```"
|
||||
b, a = _balance_code_fences(before, after)
|
||||
assert b.endswith("\n```")
|
||||
assert a.startswith("```\n")
|
||||
|
||||
def test_preserves_latest_language_when_multiple_fences(self):
|
||||
before = "```py\nprint()\n```\nmiddle\n```ts\nconst x = 1"
|
||||
after = "const y = 2\n```"
|
||||
b, a = _balance_code_fences(before, after)
|
||||
assert b.endswith("```")
|
||||
assert a.startswith("```ts\n")
|
||||
|
||||
|
||||
class TestSplitAtBoundaryWithCodeFences:
|
||||
def test_split_inside_fence_rebalances(self):
|
||||
code_block = "```python\n" + ("line\n" * 500) + "```\nafter"
|
||||
before, after = split_at_boundary(code_block, 300)
|
||||
# ``before`` must close the fence it opened.
|
||||
assert before.count("```") % 2 == 0
|
||||
# ``after`` must reopen with the same language tag.
|
||||
assert after.lstrip().startswith("```python")
|
||||
25 autogpt_platform/backend/backend/copilot/bot/threads.py Normal file
@@ -0,0 +1,25 @@
"""Thread subscription tracking.
|
||||
|
||||
When the bot creates a thread in response to an @mention, we record the
|
||||
thread ID so subsequent messages in it don't require another mention.
|
||||
Subscriptions live in Redis with a 7-day TTL — stale threads age out
|
||||
automatically.
|
||||
"""
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
THREAD_SUBSCRIPTION_TTL = 7 * 86400 # 7 days
|
||||
|
||||
|
||||
def _key(platform: str, thread_id: str) -> str:
|
||||
return f"copilot-bot:thread:{platform}:{thread_id}"
|
||||
|
||||
|
||||
async def is_subscribed(platform: str, thread_id: str) -> bool:
|
||||
redis = await get_redis_async()
|
||||
return bool(await redis.get(_key(platform, thread_id)))
|
||||
|
||||
|
||||
async def subscribe(platform: str, thread_id: str) -> None:
|
||||
redis = await get_redis_async()
|
||||
await redis.set(_key(platform, thread_id), "1", ex=THREAD_SUBSCRIPTION_TTL)
|
||||
55 autogpt_platform/backend/backend/copilot/bot/threads_test.py Normal file
@@ -0,0 +1,55 @@
"""Tests for Redis-backed thread subscription tracking."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from . import threads
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_mock():
|
||||
mock = AsyncMock()
|
||||
mock.get = AsyncMock()
|
||||
mock.set = AsyncMock()
|
||||
with patch("backend.copilot.bot.threads.get_redis_async", return_value=mock):
|
||||
yield mock
|
||||
|
||||
|
||||
class TestSubscribe:
|
||||
@pytest.mark.asyncio
|
||||
async def test_writes_key_with_ttl(self, redis_mock):
|
||||
await threads.subscribe("discord", "thread-123")
|
||||
redis_mock.set.assert_awaited_once_with(
|
||||
"copilot-bot:thread:discord:thread-123",
|
||||
"1",
|
||||
ex=threads.THREAD_SUBSCRIPTION_TTL,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_includes_platform(self, redis_mock):
|
||||
await threads.subscribe("telegram", "t-1")
|
||||
key = redis_mock.set.await_args.args[0]
|
||||
assert "telegram" in key
|
||||
assert "t-1" in key
|
||||
|
||||
|
||||
class TestIsSubscribed:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_true_when_present(self, redis_mock):
|
||||
redis_mock.get.return_value = "1"
|
||||
assert await threads.is_subscribed("discord", "thread-1") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_missing(self, redis_mock):
|
||||
redis_mock.get.return_value = None
|
||||
assert await threads.is_subscribed("discord", "thread-1") is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_same_key_as_subscribe(self, redis_mock):
|
||||
redis_mock.get.return_value = None
|
||||
await threads.is_subscribed("discord", "thread-1")
|
||||
await threads.subscribe("discord", "thread-1")
|
||||
read_key = redis_mock.get.await_args.args[0]
|
||||
write_key = redis_mock.set.await_args.args[0]
|
||||
assert read_key == write_key
|
||||
@@ -395,11 +395,17 @@ class ChatConfig(BaseSettings):
 
     @property
     def openrouter_active(self) -> bool:
-        """True when OpenRouter is enabled AND credentials are usable.
+        """True when OpenRouter config is shape-valid (flag + credentials).
 
-        Single source of truth for "will the SDK route through OpenRouter?".
-        Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
-        present — mirrors the fallback logic in ``build_sdk_env``.
+        Indicates whether OpenRouter settings are present and usable —
+        ``use_openrouter`` set, plus ``api_key`` + a valid ``base_url``,
+        mirroring the fallback logic in ``build_sdk_env``.
+
+        Note: this is a **config-shape check only**. Runtime SDK routing
+        is governed by ``effective_transport`` — subscription mode
+        bypasses OpenRouter entirely even when these fields are set, so
+        callers asking "will the SDK actually route through OpenRouter
+        for this turn?" should use ``effective_transport`` instead.
         """
         if not self.use_openrouter:
             return False
@@ -408,6 +414,34 @@ class ChatConfig(BaseSettings):
             base = base[:-3]
         return bool(self.api_key and base and base.startswith("http"))
 
+    @property
+    def effective_transport(
+        self,
+    ) -> Literal["subscription", "openrouter", "direct_anthropic"]:
+        """The transport the SDK CLI subprocess actually uses for this turn.
+
+        Detection order:
+
+        1. ``subscription`` — when ``use_claude_code_subscription`` is True
+           the CLI uses OAuth from the keychain or
+           ``CLAUDE_CODE_OAUTH_TOKEN`` and ignores ``CHAT_BASE_URL`` /
+           ``CHAT_API_KEY`` entirely (see ``build_sdk_env`` mode 1).
+        2. ``openrouter`` — when ``openrouter_active`` (use_openrouter +
+           api_key + a valid base_url).
+        3. ``direct_anthropic`` — fallback (CLI talks to api.anthropic.com
+           with ``ANTHROPIC_API_KEY`` from parent env).
+
+        Use this when the question is "which model-name format will the
+        CLI accept?" — the OpenRouter slug ``anthropic/claude-opus-4.7``
+        works through the proxy but is rejected by the subscription /
+        direct-Anthropic transports.
+        """
+        if self.use_claude_code_subscription:
+            return "subscription"
+        if self.openrouter_active:
+            return "openrouter"
+        return "direct_anthropic"
+
     @property
     def e2b_active(self) -> bool:
         """True when E2B is enabled and the API key is present.
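The detection order in `effective_transport` is a strict precedence chain: subscription beats OpenRouter, which beats direct Anthropic. A pure-function sketch of the same chain (field names mirror the diff; the surrounding config class and its other fields are assumed):

```python
from typing import Literal

Transport = Literal["subscription", "openrouter", "direct_anthropic"]


def effective_transport(
    use_subscription: bool, openrouter_active: bool
) -> Transport:
    # Subscription wins even when OpenRouter is fully configured,
    # which is why openrouter_active alone can't answer routing questions.
    if use_subscription:
        return "subscription"
    if openrouter_active:
        return "openrouter"
    return "direct_anthropic"


print(effective_transport(True, True))    # subscription
print(effective_transport(False, True))   # openrouter
print(effective_transport(False, False))  # direct_anthropic
```

The first call is the case the new docstring calls out: shape-valid OpenRouter config that is nonetheless bypassed at runtime.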
@@ -532,9 +566,13 @@ class ChatConfig(BaseSettings):
         (``claude_agent_fallback_model`` via ``_resolve_fallback_model``).
 
         Skipped when ``use_claude_code_subscription=True`` because the
-        subscription path resolves the model to ``None`` (CLI default)
-        and never calls ``_normalize_model_name``. Empty fallback strings
-        are also skipped (no fallback configured).
+        subscription path normally resolves the static config to ``None``
+        (CLI default). An LD-served override under subscription does
+        flow through ``_normalize_model_name``; the runtime guard first
+        falls back to the tier default, and only avoids a request error
+        when that default is itself valid (otherwise the original LD
+        ValueError is re-raised — see ``_resolve_sdk_model_for_request``).
+        Empty fallback strings are also skipped (no fallback configured).
         """
         if self.use_claude_code_subscription:
             return self
@@ -85,7 +85,7 @@ class CoPilotExecutor(AppProcess):
         self._run_client = None
 
         self._task_locks: dict[str, ClusterLock] = {}
-        self._active_tasks_lock = threading.Lock()
+        self._active_tasks_lock_obj: threading.Lock | None = None
 
     # ============ Main Entry Points (AppProcess interface) ============ #
 
@@ -502,6 +502,12 @@ class CoPilotExecutor(AppProcess):
|
||||
|
||||
# ============ Lazy-initialized Properties ============ #
|
||||
|
||||
@property
|
||||
def _active_tasks_lock(self) -> threading.Lock:
|
||||
if self._active_tasks_lock_obj is None:
|
||||
self._active_tasks_lock_obj = threading.Lock()
|
||||
return self._active_tasks_lock_obj
|
||||
|
||||
@property
|
||||
def cancel_thread(self) -> threading.Thread:
|
||||
if self._cancel_thread is None:
|
||||
|
||||
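The `_active_tasks_lock` change above swaps an eagerly created `threading.Lock` for a lazily created one behind a property. A common motive for this shape (assumed here, not stated in the diff) is that a raw lock object cannot be pickled, so deferring its creation keeps the instance serializable until the lock is first touched. A minimal sketch of the pattern with hypothetical names:

```python
import threading


class LazyLockHolder:
    """Defer lock creation so the instance stays picklable before first use."""

    def __init__(self) -> None:
        # Only a None placeholder is stored at construction time;
        # threading.Lock objects themselves cannot be pickled.
        self._lock_obj: threading.Lock | None = None

    @property
    def lock(self) -> threading.Lock:
        # Note: this lazy init is itself unguarded, so the first access
        # should happen before worker threads start contending for it.
        if self._lock_obj is None:
            self._lock_obj = threading.Lock()
        return self._lock_obj


holder = LazyLockHolder()
with holder.lock:  # first access creates the lock
    pass
```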
@@ -35,12 +35,14 @@ SHUTDOWN_ERROR_MESSAGE = (
    "Copilot executor shut down before this turn finished. Please retry."
)

# Max time execute() blocks after calling future.cancel() / when draining a
# soon-to-be-cancelled future. Gives _execute_async's own finally a chance to
# publish the accurate terminal state over the Redis CAS; long enough to let
# an in-flight Redis call settle, short enough that shutdown doesn't stall.
# Max time execute() blocks after requesting async turn cancellation. The worker
# waits for normal cleanup so late stream writes do not race the manager, but it
# must still escape to the sync fail-close safety net if cleanup wedges.
_CANCEL_GRACE_SECONDS = 5.0

# How long to wait before logging again that a cancelled turn is still draining.
_CANCEL_DRAIN_LOG_INTERVAL_SECONDS = 1.0

# Max time the sync safety net itself spends on a single Redis CAS. Without
# this bound the whole point of ``sync_fail_close_session`` is defeated —
# ``mark_session_completed`` would hang on the same broken Redis that caused
@@ -92,9 +94,11 @@ def sync_fail_close_session(
        timeout=_FAIL_CLOSE_REDIS_TIMEOUT,
    )

    coro = _bounded()
    try:
        future = asyncio.run_coroutine_threadsafe(_bounded(), execution_loop)
        future = asyncio.run_coroutine_threadsafe(coro, execution_loop)
    except RuntimeError as e:
        coro.close()
        # execution_loop is closed — happens if cleanup() already ran the
        # per-worker teardown. Nothing we can do; let the stale-session
        # watchdog reap it.
@@ -336,8 +340,7 @@ class CoPilotProcessor:

        Thin wrapper around :meth:`_execute`. The ``try/finally`` here
        guarantees :func:`sync_fail_close_session` runs on every exit
        path — normal completion, exception, or a wedged event loop
        that escapes via :data:`_CANCEL_GRACE_SECONDS` timeout.
        path — normal completion or exception.
        ``mark_session_completed`` is an atomic CAS on
        ``status == "running"``, so when the async path already wrote a
        terminal state the sync call is a cheap no-op.
@@ -370,40 +373,92 @@ class CoPilotProcessor:
        that lives in :func:`sync_fail_close_session` which the outer
        :meth:`execute` always invokes on exit.
        """
        task_ready: concurrent.futures.Future[asyncio.Task] = (
            concurrent.futures.Future()
        )

        async def run_async_turn():
            task = asyncio.current_task()
            if task is not None and not task_ready.done():
                task_ready.set_result(task)
            return await self._execute_async(entry, cancel, cluster_lock, log)

        future = asyncio.run_coroutine_threadsafe(
            self._execute_async(entry, cancel, cluster_lock, log),
            run_async_turn(),
            self.execution_loop,
        )

        # Wait for completion, checking cancel periodically
        while not future.done():
        cancel_requested = False
        cancel_started_at: float | None = None
        last_cancel_log_at: float | None = None

        def request_cancel() -> None:
            nonlocal cancel_requested, cancel_started_at, last_cancel_log_at
            log.info("Cancellation requested")
            try:
                task = task_ready.result(timeout=0)
            except concurrent.futures.TimeoutError:
                # Sub-millisecond race: ``run_coroutine_threadsafe`` returned
                # before ``run_async_turn`` actually started, so
                # ``task_ready.set_result`` has not run yet. ``future.cancel``
                # on a ``concurrent.futures.Future`` whose underlying task may
                # already be picked up by the loop is best-effort — frequently
                # a no-op. The slow path is intentional: ``cancel.is_set()``
                # is polled inside ``_execute_async`` and the bounded
                # ``_CANCEL_GRACE_SECONDS`` drain below force-cancels and falls
                # through to ``sync_fail_close_session``, so the worst-case
                # observable behaviour is "cancel takes ~5s in this rare race"
                # rather than a stuck session.
                future.cancel()
            else:
                self.execution_loop.call_soon_threadsafe(task.cancel)
            cancel_requested = True
            cancel_started_at = time.monotonic()
            last_cancel_log_at = cancel_started_at

        def log_cancel_wait() -> None:
            nonlocal last_cancel_log_at
            if cancel_started_at is None or last_cancel_log_at is None:
                return
            now = time.monotonic()
            if now - last_cancel_log_at < _CANCEL_DRAIN_LOG_INTERVAL_SECONDS:
                return
            elapsed = now - cancel_started_at
            log.warning(f"Waiting for cancelled turn to drain ({elapsed:.1f}s elapsed)")
            last_cancel_log_at = now

        def cancel_drain_timed_out() -> bool:
            if cancel_started_at is None:
                return False
            elapsed = time.monotonic() - cancel_started_at
            if elapsed < _CANCEL_GRACE_SECONDS:
                return False
            log.warning(
                f"Cancelled turn did not drain within {_CANCEL_GRACE_SECONDS:.1f}s; "
                "falling through to sync fail-close"
            )
            future.cancel()
            return True

        # Wait for completion, checking cancel periodically. A cancellation
        # request waits for normal async cleanup, but remains bounded so the
        # worker does not refresh the per-session lock forever on a wedged turn.
        while True:
            try:
                future.result(timeout=1.0)
            except asyncio.TimeoutError:
                if cancel.is_set():
                    log.info("Cancellation requested")
                    future.cancel()
                    # Give _execute_async's own finally a short window to
                    # publish its accurate terminal state before the outer
                    # sync safety net fires.
                    try:
                        future.result(timeout=_CANCEL_GRACE_SECONDS)
                    except BaseException:
                        pass
                    return
            except concurrent.futures.CancelledError:
                if cancel_requested or cancel.is_set():
                    return
            cluster_lock.refresh()

        if not future.cancelled():
            # Bounded timeout so a wedged event loop can't trap us here —
            # on timeout we escape to execute()'s finally and the sync
            # safety net fires.
            try:
                future.result(timeout=_CANCEL_GRACE_SECONDS)
                raise
            except concurrent.futures.TimeoutError:
                log.warning(
                    "Future did not complete within grace window; "
                    "falling through to sync fail-close"
                )
            if cancel.is_set() and not cancel_requested:
                request_cancel()
            elif cancel_requested and cancel_started_at is not None:
                if cancel_drain_timed_out():
                    return
                log_cancel_wait()
            cluster_lock.refresh()

    async def _execute_async(
        self,

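The `task_ready` handshake above exists because `concurrent.futures.Future.cancel()` cannot interrupt a coroutine the loop has already started running; only `Task.cancel()`, marshalled onto the loop thread via `call_soon_threadsafe`, delivers a `CancelledError` into the coroutine's `await` points. A reduced, self-contained sketch of the handshake with hypothetical names:

```python
import asyncio
import concurrent.futures
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

task_ready: concurrent.futures.Future = concurrent.futures.Future()
cancelled_cleanly = threading.Event()


async def turn() -> None:
    task = asyncio.current_task()
    if task is not None and not task_ready.done():
        task_ready.set_result(task)  # hand the Task back to the worker thread
    try:
        await asyncio.sleep(3600)
    except asyncio.CancelledError:
        cancelled_cleanly.set()  # cleanup runs before the task finishes
        raise


future = asyncio.run_coroutine_threadsafe(turn(), loop)
task = task_ready.result(timeout=5)
# future.cancel() would be a no-op once the loop has started the coroutine;
# Task.cancel() must run on the loop thread itself.
loop.call_soon_threadsafe(task.cancel)
try:
    future.result(timeout=5)
except concurrent.futures.CancelledError:
    pass
loop.call_soon_threadsafe(loop.stop)
```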
@@ -496,3 +496,108 @@ class TestExecuteSafetyNet:
        assert call_log == [
            "sync-ok"
        ], f"expected sync_fail_close_session to run once, got {call_log!r}"

    def test_cancel_waits_for_async_task_to_finish(self, exec_loop) -> None:
        """A cancel request must not let ``_execute`` return while the
        underlying asyncio task is still cleaning up. Returning early would
        make the manager release the session lock while late stream writes
        are still possible."""
        proc = CoPilotProcessor()
        self._attach_exec_loop(proc, exec_loop)

        started = threading.Event()
        cancel_seen = threading.Event()
        release_cleanup = threading.Event()
        finished = threading.Event()

        async def _stubborn_cancel(*_args, **_kwargs):
            started.set()
            try:
                await asyncio.sleep(3600)
            except asyncio.CancelledError:
                cancel_seen.set()
                while not release_cleanup.is_set():
                    await asyncio.sleep(0.01)
            finally:
                finished.set()

        proc._execute_async = _stubborn_cancel  # type: ignore[method-assign]

        cancel = threading.Event()
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            fut = pool.submit(
                proc._execute,
                _make_entry(),
                cancel,
                MagicMock(),
                _make_log(),
            )
            assert started.wait(timeout=5)

            cancel.set()
            assert cancel_seen.wait(timeout=5)
            assert not fut.done()

            release_cleanup.set()
            fut.result(timeout=5)
            assert finished.is_set()
        finally:
            pool.shutdown(wait=True)

    def test_cancel_wait_has_bounded_escape_hatch(self, exec_loop) -> None:
        """A wedged async cleanup must not keep the worker refreshing the
        session lock forever; after the grace window, ``_execute`` returns
        so ``execute`` can run the sync fail-close safety net."""
        proc = CoPilotProcessor()
        self._attach_exec_loop(proc, exec_loop)

        started = threading.Event()
        cancel_seen = threading.Event()
        release_cleanup = threading.Event()
        finished = threading.Event()

        async def _wedged_cancel(*_args, **_kwargs):
            started.set()
            try:
                await asyncio.sleep(3600)
            except asyncio.CancelledError:
                cancel_seen.set()
                while not release_cleanup.is_set():
                    try:
                        await asyncio.sleep(0.01)
                    except asyncio.CancelledError:
                        pass
            finally:
                finished.set()

        proc._execute_async = _wedged_cancel  # type: ignore[method-assign]

        cancel = threading.Event()
        cluster_lock = MagicMock()
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            with patch(
                "backend.copilot.executor.processor._CANCEL_GRACE_SECONDS",
                0.05,
            ):
                fut = pool.submit(
                    proc._execute,
                    _make_entry(),
                    cancel,
                    cluster_lock,
                    _make_log(),
                )
                assert started.wait(timeout=5)

                cancel.set()
                assert cancel_seen.wait(timeout=5)
                fut.result(timeout=5)

            assert not finished.is_set()
            assert cluster_lock.refresh.call_count < 10

            release_cleanup.set()
            assert finished.wait(timeout=5)
        finally:
            pool.shutdown(wait=True)

@@ -71,7 +71,9 @@ COPILOT_EXECUTION_EXCHANGE = Exchange(
    durable=True,
    auto_delete=False,
)
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
# ``_v2`` suffix marks the classic→quorum rollover; old-image consumers
# drain the unsuffixed queue. Orphans cleaned up in a follow-up PR.
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue_v2"
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"

COPILOT_CANCEL_EXCHANGE = Exchange(
@@ -80,7 +82,7 @@ COPILOT_CANCEL_EXCHANGE = Exchange(
    durable=True,
    auto_delete=False,
)
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue_v2"


def get_session_lock_key(session_id: str) -> str:
@@ -118,6 +120,9 @@ def create_copilot_queue_config() -> RabbitMQConfig:
        durable=True,
        auto_delete=False,
        arguments={
            # Quorum (not classic mirrored) for leader election + stronger
            # replication across RabbitMQ 4.x cluster nodes.
            "x-queue-type": "quorum",
            # Consumer timeout matches the pod graceful-shutdown window so a
            # rolling deploy never forces redelivery of a turn that the pod
            # is still legitimately finishing.
@@ -131,7 +136,7 @@ def create_copilot_queue_config() -> RabbitMQConfig:
            # limit), apply a policy:
            #
            #   rabbitmqctl set_policy copilot-consumer-timeout \
            #     "^copilot_execution_queue$" \
            #     "^copilot_execution_queue_v2$" \
            #     '{"consumer-timeout": 21600000}' \
            #     --apply-to queues
            #
@@ -139,8 +144,7 @@ def create_copilot_queue_config() -> RabbitMQConfig:
            # to match the code's value the policy is redundant for new
            # pods and can be removed after a stable deploy if desired —
            # but it's harmless to leave in place.
            "x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
            * 1000,
            "x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS * 1000,
        },
    )
    cancel_queue = Queue(
@@ -149,6 +153,7 @@ def create_copilot_queue_config() -> RabbitMQConfig:
        routing_key="",  # not used for FANOUT
        durable=True,
        auto_delete=False,
        arguments={"x-queue-type": "quorum"},
    )
    return RabbitMQConfig(
        vhost="/",

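The `x-queue-type` and `x-consumer-timeout` entries above are standard RabbitMQ queue arguments fixed at declaration time; a queue's type cannot be changed in place, which is why the rollover needs the `_v2` rename rather than a redeclaration. A hedged sketch of an equivalent argument set (the constant value and the commented `pika` call are illustrative assumptions, not the repo's real config):

```python
# Assumed placeholder; the real value lives in the repo's config.
CONSUMER_TIMEOUT_SECONDS = 300


def quorum_queue_arguments(consumer_timeout_seconds: int) -> dict[str, object]:
    """Queue arguments matching the shape of the declaration above."""
    return {
        # Quorum queues use Raft replication; the type is immutable after
        # declaration, so changing it requires a new queue name.
        "x-queue-type": "quorum",
        # RabbitMQ expects the consumer timeout in milliseconds.
        "x-consumer-timeout": consumer_timeout_seconds * 1000,
    }


# Against a live broker this would be declared roughly as (pika shown):
#   channel.queue_declare(
#       queue="copilot_execution_queue_v2",
#       durable=True,
#       arguments=quorum_queue_arguments(CONSUMER_TIMEOUT_SECONDS),
#   )
args = quorum_queue_arguments(CONSUMER_TIMEOUT_SECONDS)
```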
@@ -1,9 +1,22 @@
from unittest.mock import MagicMock

import pytest

from .falkordb_driver import AutoGPTFalkorDriver


def test_build_fulltext_query_uses_unquoted_group_ids_for_falkordb() -> None:
    driver = AutoGPTFalkorDriver()
@pytest.fixture
def driver() -> AutoGPTFalkorDriver:
    # ``build_fulltext_query`` is a pure string-builder that never touches
    # the FalkorDB client; injecting a mock avoids the eager Redis probe
    # that the upstream ``FalkorDriver.__init__`` runs against
    # ``localhost:6379``.
    return AutoGPTFalkorDriver(falkor_db=MagicMock())


def test_build_fulltext_query_uses_unquoted_group_ids_for_falkordb(
    driver: AutoGPTFalkorDriver,
) -> None:
    query = driver.build_fulltext_query(
        "Sarah",
        group_ids=["user_883cc9da-fe37-4863-839b-acba022bf3ef"],
@@ -13,18 +26,18 @@ def test_build_fulltext_query_uses_unquoted_group_ids_for_falkordb() -> None:
    assert '"user_883cc9da-fe37-4863-839b-acba022bf3ef"' not in query


def test_build_fulltext_query_joins_multiple_group_ids_with_or() -> None:
    driver = AutoGPTFalkorDriver()

def test_build_fulltext_query_joins_multiple_group_ids_with_or(
    driver: AutoGPTFalkorDriver,
) -> None:
    query = driver.build_fulltext_query("Sarah", group_ids=["user_a", "user_b"])

    assert query == "(@group_id:user_a|user_b) (Sarah)"


def test_stopwords_only_query_returns_group_filter_only() -> None:
def test_stopwords_only_query_returns_group_filter_only(
    driver: AutoGPTFalkorDriver,
) -> None:
    """Line 25: sanitized_query is empty (all stopwords) but group_ids present."""
    driver = AutoGPTFalkorDriver()

    # "the" is a common stopword — the query should reduce to just the group filter.
    query = driver.build_fulltext_query(
        "the",
@@ -34,10 +47,10 @@ def test_stopwords_only_query_returns_group_filter_only() -> None:
    assert query == "(@group_id:user_abc)"


def test_query_without_group_ids_returns_parenthesized_query() -> None:
def test_query_without_group_ids_returns_parenthesized_query(
    driver: AutoGPTFalkorDriver,
) -> None:
    """Line 27: sanitized_query has content but no group_ids provided."""
    driver = AutoGPTFalkorDriver()

    query = driver.build_fulltext_query("Sarah", group_ids=None)

    assert query == "(Sarah)"

@@ -21,8 +21,10 @@ from backend.copilot.pending_messages import (
    drain_pending_messages,
    format_pending_as_user_message,
    push_pending_message,
    push_pending_message_if_session_running,
)
from backend.copilot.stream_registry import get_session as get_active_session_meta
from backend.copilot.stream_registry import get_session_meta_key
from backend.data.redis_client import get_redis_async
from backend.data.redis_helpers import incr_with_ttl
from backend.data.workspace import resolve_workspace_files
@@ -44,8 +46,8 @@ _PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:"
async def is_turn_in_flight(session_id: str) -> bool:
    """Return ``True`` when a copilot turn is actively running for *session_id*.

    Used by the unified POST /stream entry point and the autopilot block so
    a second message arriving while an earlier turn is still executing gets
    Used by the HTTP pending-message endpoint and the autopilot block so a
    second message arriving while an earlier turn is still executing gets
    queued into the pending buffer instead of racing the in-flight turn on
    the cluster lock.
    """
@@ -54,15 +56,14 @@ async def is_turn_in_flight(session_id: str) -> bool:


class QueuePendingMessageResponse(BaseModel):
    """Response returned by ``POST /stream`` with status 202 when a message
    is queued because the session already has a turn in flight.
    """Response returned when a message is queued because the session already
    has a turn in flight.

    - ``buffer_length``: how many messages are now in the session's
      pending buffer (after this push)
    - ``max_buffer_length``: the per-session cap (server-side constant)
    - ``turn_in_flight``: ``True`` if a copilot turn was running when
      we checked — purely informational for UX feedback. Always ``True``
      for responses from ``POST /stream`` with status 202.
      we checked — purely informational for UX feedback.
    """

    buffer_length: int
@@ -76,11 +77,12 @@ async def queue_user_message(
    message: str,
    context: PendingMessageContext | None = None,
    file_ids: list[str] | None = None,
    require_turn_in_flight: bool = False,
) -> QueuePendingMessageResponse:
    """Push *message* into the per-session pending buffer.

    The shared primitive for "a message arrived while a turn is in flight" —
    called from the unified POST /stream handler and the autopilot block.
    called from the HTTP pending-message path and the autopilot block.
    Call-frequency rate limiting is the caller's responsibility (HTTP path
    enforces it; internal block callers skip it).
    """
@@ -89,6 +91,18 @@ async def queue_user_message(
        file_ids=file_ids or [],
        context=context,
    )
    if require_turn_in_flight:
        new_len = await push_pending_message_if_session_running(
            session_id,
            pending,
            session_meta_key=get_session_meta_key(session_id),
        )
        return QueuePendingMessageResponse(
            buffer_length=new_len or 0,
            max_buffer_length=MAX_PENDING_MESSAGES,
            turn_in_flight=new_len is not None,
        )

    new_len = await push_pending_message(session_id, pending)
    return QueuePendingMessageResponse(
        buffer_length=new_len,
@@ -107,7 +121,7 @@ async def queue_pending_for_http(
) -> QueuePendingMessageResponse:
    """HTTP-facing wrapper around :func:`queue_user_message`.

    Owns the HTTP-only concerns that sat inline in ``stream_chat_post``:
    Owns the HTTP-only concerns for the pending-message route:

    1. Per-user call-rate cap (429 on overflow).
    2. File-ID sanitisation against the user's own workspace.
@@ -116,19 +130,8 @@ async def queue_pending_for_http(

    Raises :class:`HTTPException` with status 429 if the rate cap is hit;
    otherwise returns the ``QueuePendingMessageResponse`` the handler can
    serialise 1:1 into the 202 body.
    serialise 1:1.
    """
    call_count = await check_pending_call_rate(user_id)
    if call_count > PENDING_CALL_LIMIT:
        raise HTTPException(
            status_code=429,
            detail=(
                f"Too many queued message requests this minute: limit is "
                f"{PENDING_CALL_LIMIT} per {PENDING_CALL_WINDOW_SECONDS}s "
                "across all sessions"
            ),
        )

    sanitized_file_ids: list[str] | None = None
    if file_ids:
        files = await resolve_workspace_files(user_id, file_ids)
@@ -141,12 +144,41 @@ async def queue_pending_for_http(
    # typos, but the upstream ``StreamChatRequest.context: dict[str, str]``
    # is already schemaless, so the strict mode adds no real safety.
    queue_context = PendingMessageContext.model_validate(context) if context else None
    return await queue_user_message(

    # Push first via the Lua CAS gate. Bumping the per-user call-rate
    # counter BEFORE the push would charge a budget tick on every TOCTOU
    # loss against turn completion (status flips running→completed between
    # the FE's is_turn_in_flight check and our gate), which both this
    # endpoint and the POST /stream queue-fall-through can hit. Pushing
    # first lets the gate own the no-op short-circuit.
    response = await queue_user_message(
        session_id=session_id,
        message=message,
        context=queue_context,
        file_ids=sanitized_file_ids,
        require_turn_in_flight=True,
    )
    if not response.turn_in_flight:
        raise HTTPException(
            status_code=409,
            detail="Session has no active turn. Start a new turn with POST /stream.",
        )

    # Push landed — now charge the rate counter. If this tick crosses the
    # limit we still keep the queued message (next drain will pick it up)
    # but report 429 so the client backs off further pushes.
    call_count = await check_pending_call_rate(user_id)
    if call_count > PENDING_CALL_LIMIT:
        raise HTTPException(
            status_code=429,
            detail=(
                f"Too many queued message requests this minute: limit is "
                f"{PENDING_CALL_LIMIT} per {PENDING_CALL_WINDOW_SECONDS}s "
                "across all sessions"
            ),
        )

    return response


async def check_pending_call_rate(user_id: str) -> int:
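The gate-then-charge reordering above can be sketched in isolation. Everything below is hypothetical stand-in code (the real gate is a Redis Lua script keyed on session status), but it captures the invariant: a TOCTOU loss against turn completion must not cost the user a rate-budget tick.

```python
class RateLimited(Exception):
    """Stand-in for the HTTP 429 path."""


class NoActiveTurn(Exception):
    """Stand-in for the HTTP 409 path."""


def queue_with_gate(push_if_running, charge_rate, limit: int) -> int:
    """push_if_running: returns new buffer length, or None if no turn in flight.
    charge_rate: returns the caller's post-increment call count."""
    # 1. Gate first: if the turn just completed, no counter is touched.
    new_len = push_if_running()
    if new_len is None:
        raise NoActiveTurn
    # 2. Push landed — now charge. On overflow the message stays queued,
    #    but the caller is told to back off.
    if charge_rate() > limit:
        raise RateLimited
    return new_len
```

The ordering matters only because the two operations are not atomic together; charging first would leak budget on every race the diff's comment describes.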
@@ -366,14 +398,35 @@ async def persist_pending_as_user_rows(
        transcript_builder.restore(transcript_snapshot)
        if on_rollback is not None:
            on_rollback(session_anchor)
        # ``push_pending_message`` uses the bounded ``capped_rpush`` (LTRIM
        # to ``MAX_PENDING_MESSAGES``). If ≥``MAX_PENDING_MESSAGES`` fresh
        # follow-ups arrived between the original drain and this rollback
        # (heavy typing across a tool boundary), the LTRIM drops oldest
        # entries — which can include the ones we just re-pushed. The model
        # already saw that content (mid-turn injection earlier in the
        # turn), but no DB row lands so the user sees no UI bubble.
        # Surface a warning so the bounded data-loss path is visible in
        # prod (it is rare and would otherwise be observable only via the
        # absence of a bubble).
        rollback_buffer_at_cap = False
        for pm in pending:
            try:
                await push_pending_message(session.session_id, pm)
                new_length = await push_pending_message(session.session_id, pm)
                if new_length >= MAX_PENDING_MESSAGES:
                    rollback_buffer_at_cap = True
            except Exception:
                logger.exception(
                    "%s Failed to re-queue mid-turn follow-up on rollback",
                    log_prefix,
                )
        if rollback_buffer_at_cap:
            logger.warning(
                "%s Rollback re-push hit pending-buffer cap (MAX=%d); a "
                "previously queued follow-up may have been LTRIM-displaced "
                "(silent UI-bubble drop). Investigate if observed.",
                log_prefix,
                MAX_PENDING_MESSAGES,
            )
        return False

    logger.info(
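`capped_rpush` is described above as an RPUSH bounded by LTRIM. A minimal model of those semantics over a plain Python list (the helper's real name is from the diff; the Redis pipeline in the comment is an assumption about its shape):

```python
def capped_rpush_semantics(buffer: list[str], item: str, cap: int) -> int:
    """Model of RPUSH + LTRIM -cap..-1: append, then keep only the newest
    ``cap`` entries. Returns the post-trim length, which is why the
    rollback path treats ``new_length >= MAX_PENDING_MESSAGES`` as a
    possible-displacement signal (a trimmed list never exceeds the cap)."""
    buffer.append(item)   # RPUSH key item
    del buffer[:-cap]     # LTRIM key -cap -1  (oldest entries dropped)
    return len(buffer)    # length after trim

# Against Redis this would plausibly be one pipeline, e.g. with redis-py:
#   pipe.rpush(key, item); pipe.ltrim(key, -cap, -1); pipe.expire(key, ttl)
```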
@@ -4,17 +4,20 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi import HTTPException

from backend.copilot import pending_message_helpers as helpers_module
from backend.copilot.pending_message_helpers import (
    PENDING_CALL_LIMIT,
    QueuePendingMessageResponse,
    check_pending_call_rate,
    combine_pending_with_current,
    drain_pending_safe,
    insert_pending_before_last,
    persist_session_safe,
    queue_pending_for_http,
)
from backend.copilot.pending_messages import PendingMessage
from backend.copilot.pending_messages import MAX_PENDING_MESSAGES, PendingMessage

# ── check_pending_call_rate ────────────────────────────────────────────

@@ -46,6 +49,112 @@ async def test_check_pending_call_rate_fails_open_on_redis_error(
    assert result == 0


# ── queue_pending_for_http: gate-then-bump ordering ───────────────────


@pytest.mark.asyncio
async def test_queue_pending_does_not_charge_rate_on_toctou_409(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the Lua gate refuses the push because the turn just completed,
    the per-user call-rate counter must NOT have been incremented — bumping
    it before the gate would charge a budget tick for every TOCTOU loss
    against turn completion (race that both this endpoint and the POST
    /stream queue-fall-through can trigger)."""
    monkeypatch.setattr(
        helpers_module,
        "queue_user_message",
        AsyncMock(
            return_value=QueuePendingMessageResponse(
                buffer_length=0,
                max_buffer_length=MAX_PENDING_MESSAGES,
                turn_in_flight=False,
            )
        ),
    )
    rate_mock = AsyncMock(return_value=1)
    monkeypatch.setattr(helpers_module, "check_pending_call_rate", rate_mock)
    monkeypatch.setattr(
        helpers_module, "resolve_workspace_files", AsyncMock(return_value=[])
    )

    with pytest.raises(HTTPException) as exc_info:
        await queue_pending_for_http(
            session_id="sess-1",
            user_id="user-1",
            message="hi",
            context=None,
            file_ids=None,
        )
    assert exc_info.value.status_code == 409
    rate_mock.assert_not_awaited()


@pytest.mark.asyncio
async def test_queue_pending_charges_rate_only_after_successful_push(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    response = QueuePendingMessageResponse(
        buffer_length=2,
        max_buffer_length=MAX_PENDING_MESSAGES,
        turn_in_flight=True,
    )
    queue_mock = AsyncMock(return_value=response)
    monkeypatch.setattr(helpers_module, "queue_user_message", queue_mock)
    rate_mock = AsyncMock(return_value=PENDING_CALL_LIMIT)
    monkeypatch.setattr(helpers_module, "check_pending_call_rate", rate_mock)
    monkeypatch.setattr(
        helpers_module, "resolve_workspace_files", AsyncMock(return_value=[])
    )

    result = await queue_pending_for_http(
        session_id="sess-1",
        user_id="user-1",
        message="hi",
        context=None,
        file_ids=None,
    )

    assert result is response
    queue_mock.assert_awaited_once()
    rate_mock.assert_awaited_once_with("user-1")


@pytest.mark.asyncio
async def test_queue_pending_429_after_push_when_limit_exceeded(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the post-push rate counter crosses the limit, the message stays
    in the buffer (next drain will pick it up) but the response is 429 so
    the client backs off."""
    response = QueuePendingMessageResponse(
        buffer_length=3,
        max_buffer_length=MAX_PENDING_MESSAGES,
        turn_in_flight=True,
    )
    queue_mock = AsyncMock(return_value=response)
    monkeypatch.setattr(helpers_module, "queue_user_message", queue_mock)
    monkeypatch.setattr(
        helpers_module,
        "check_pending_call_rate",
        AsyncMock(return_value=PENDING_CALL_LIMIT + 1),
    )
    monkeypatch.setattr(
        helpers_module, "resolve_workspace_files", AsyncMock(return_value=[])
    )

    with pytest.raises(HTTPException) as exc_info:
        await queue_pending_for_http(
            session_id="sess-1",
            user_id="user-1",
            message="hi",
            context=None,
            file_ids=None,
        )
    assert exc_info.value.status_code == 429
    queue_mock.assert_awaited_once()


@pytest.mark.asyncio
async def test_check_pending_call_rate_at_limit(
    monkeypatch: pytest.MonkeyPatch,
||||
@@ -29,32 +29,21 @@ from typing import Any, cast
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.redis_helpers import capped_rpush
|
||||
from backend.data.redis_helpers import capped_rpush, capped_rpush_if_hash_field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-session cap. Higher values risk a runaway consumer; lower values
|
||||
# risk dropping user input under heavy typing. 10 was chosen as a
|
||||
# reasonable ceiling — a user typing faster than the copilot can drain
|
||||
# between tool rounds is already an unusual usage pattern.
|
||||
# Per-session cap; typing faster than the copilot drains is already unusual.
|
||||
MAX_PENDING_MESSAGES = 10
|
||||
|
||||
# Redis key + TTL. The buffer is ephemeral: if a turn completes or the
|
||||
# executor dies, the pending messages should either have been drained
|
||||
# already or are safe to drop (the user can resend).
# Ephemeral buffer: undrained messages are safe to drop at TTL expiry.
_PENDING_KEY_PREFIX = "copilot:pending:"
_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:"
_PENDING_TTL_SECONDS = 3600  # 1 hour — matches stream_ttl default

# Secondary queue that carries drained-but-awaiting-persist PendingMessages
# from the MCP tool wrapper (which drains the primary buffer and injects
# into tool output for the LLM) to sdk/service.py's _dispatch_response
# handler for StreamToolOutputAvailable, which pops and persists them as a
# separate user row chronologically after the tool_result row. This is the
# hand-off between "Claude saw the follow-up mid-turn" (wrapper) and "UI
# renders a user bubble for it" (service). Rollback path re-queues into
# the PRIMARY buffer so the next turn-start drain picks them up if the
# user-row persist fails.
# Secondary queue: carries drained-but-awaiting-persist PendingMessages from
# the tool wrapper (which injects them into tool output) to sdk/service.py
# (which persists a user row after the tool_result row).
_PERSIST_QUEUE_KEY_PREFIX = "copilot:pending-persist:"

# Payload sent on the pub/sub notify channel. Subscribers treat any
@@ -65,13 +54,8 @@ _NOTIFY_PAYLOAD = "1"
class PendingMessageContext(BaseModel):
    """Structured page context attached to a pending message.

    Default ``extra='ignore'`` (pydantic's default): unknown keys from
    the loose HTTP-level ``StreamChatRequest.context: dict[str, str]``
    are silently dropped rather than raising ``ValidationError`` on
    forward-compat additions. The strict ``extra='forbid'`` mode was
    removed after sentry r3105553772 — strict validation at this
    boundary only added a 500 footgun; the upstream request model is
    already schemaless so strict mode protects nothing.
    Unknown keys are silently dropped: the upstream request model is
    ``dict[str, str]``, so strict validation here only adds 500 footguns.
    """

    url: str | None = Field(default=None, max_length=2_000)
@@ -84,19 +68,16 @@ class PendingMessage(BaseModel):
    content: str = Field(min_length=1, max_length=32_000)
    file_ids: list[str] = Field(default_factory=list, max_length=20)
    context: PendingMessageContext | None = None
    # Wall-clock time (unix seconds, float) the message was queued by the
    # user. Used by the turn-start drain to order pending relative to the
    # turn's ``current`` message: items typed *before* the current's
    # /stream arrival go ahead of it; items typed *after* (race path,
    # queued while the /stream HTTP request was still processing) go
    # after. Defaults to 0.0 for backward compatibility with entries
    # written before this field existed — those sort as "before everything"
    # which matches the pre-fix behaviour.
    # Enqueue time (unix seconds) so the turn-start drain can order pending
    # messages relative to the turn's ``current`` message.
    enqueued_at: float = Field(default_factory=time.time)


def _buffer_key(session_id: str) -> str:
    return f"{_PENDING_KEY_PREFIX}{session_id}"
    # Hash-tag braces colocate this key with stream_registry's session-meta key
    # on the same Redis Cluster slot, which the gated-rpush Lua script needs
    # (multi-key scripts return CROSSSLOT when KEYS hash to different slots).
    return f"{_PENDING_KEY_PREFIX}{{{session_id}}}"


def _notify_channel(session_id: str) -> str:
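The `enqueued_at` ordering described in the comment above can be sketched as a pure-Python illustration (the `order_around_current` helper and the tuple shape are hypothetical, written for this example, not the project's actual drain code):

```python
def order_around_current(
    pending: list[tuple[str, float]],  # (content, enqueued_at)
    current: tuple[str, float],
) -> list[tuple[str, float]]:
    """Place pending items typed before the current message ahead of it,
    and items typed after it (the race path) behind it."""
    before = [m for m in pending if m[1] <= current[1]]
    after = [m for m in pending if m[1] > current[1]]
    # Legacy entries default to enqueued_at=0.0, so they sort as
    # "before everything", matching the pre-fix behaviour.
    ordered_before = sorted(before, key=lambda m: m[1])
    ordered_after = sorted(after, key=lambda m: m[1])
    return ordered_before + [current] + ordered_after


msgs = [("late follow-up", 105.0), ("early note", 98.0), ("legacy", 0.0)]
ordered = order_around_current(msgs, ("current", 100.0))
print([m[0] for m in ordered])
# → ['legacy', 'early note', 'current', 'late follow-up']
```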
@@ -104,12 +85,7 @@ def _notify_channel(session_id: str) -> str:


def _decode_redis_item(item: Any) -> str:
    """Decode a redis-py list item to a str.

    redis-py returns ``bytes`` when ``decode_responses=False`` and ``str``
    when ``decode_responses=True``. This helper handles both so callers
    don't have to repeat the isinstance guard.
    """
    """Decode a redis-py list item to str (handles ``bytes`` and ``str``)."""
    return item.decode("utf-8") if isinstance(item, bytes) else str(item)


@@ -117,22 +93,11 @@ async def push_pending_message(
    session_id: str,
    message: PendingMessage,
) -> int:
    """Append a pending message to the session's buffer.
    """Append a pending message to the session's buffer, capped at
    ``MAX_PENDING_MESSAGES`` (oldest trimmed). Returns the new buffer length.

    Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by
    trimming from the left (oldest) — the newest message always wins if
    the user has been typing faster than the copilot can drain.

    Delegates to :func:`backend.data.redis_helpers.capped_rpush` so RPUSH
    + LTRIM + EXPIRE + LLEN run atomically (MULTI/EXEC) in one round
    trip; a concurrent drain (LPOP) can no longer observe the list
    temporarily over ``MAX_PENDING_MESSAGES``.

    Note on durability: if the executor turn crashes after a push but before
    the drain window runs, the message remains in Redis until the TTL expires
    (``_PENDING_TTL_SECONDS``, currently 1 hour). It is delivered on the
    next turn that drains the buffer. If no turn runs within the TTL the
    message is silently dropped; the user may resend it.
    The buffer survives consumer crashes until ``_PENDING_TTL_SECONDS``
    expires; messages not drained within that window are dropped.
    """
    redis = await get_redis_async()
    key = _buffer_key(session_id)
@@ -146,10 +111,14 @@ async def push_pending_message(
        ttl_seconds=_PENDING_TTL_SECONDS,
    )

    # Fire-and-forget notify. Subscribers use this as a wake-up hint;
    # the buffer itself is authoritative so a lost notify is harmless.
    # Fire-and-forget wake-up hint via sharded pub/sub (SPUBLISH routes to
    # one shard vs classic PUBLISH's cluster-bus broadcast). Use
    # execute_command because redis-py 6.x AsyncRedisCluster has no
    # spublish() wrapper.
    try:
        await redis.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD)
        await redis.execute_command(
            "SPUBLISH", _notify_channel(session_id), _NOTIFY_PAYLOAD
        )
    except Exception as e:  # pragma: no cover
        logger.warning("pending_messages: publish failed for %s: %s", session_id, e)

@@ -161,6 +130,51 @@ async def push_pending_message(
    return new_length


async def push_pending_message_if_session_running(
    session_id: str,
    message: PendingMessage,
    *,
    session_meta_key: str,
) -> int | None:
    """Append a pending message only while the stream meta is still running."""
    redis = await get_redis_async()
    key = _buffer_key(session_id)
    payload = message.model_dump_json()

    new_length = await capped_rpush_if_hash_field(
        redis,
        hash_key=session_meta_key,
        hash_field="status",
        expected="running",
        list_key=key,
        value=payload,
        max_len=MAX_PENDING_MESSAGES,
        ttl_seconds=_PENDING_TTL_SECONDS,
    )
    if new_length is None:
        logger.info(
            "pending_messages: skipped push to session=%s because no running turn exists",
            session_id,
        )
        return None

    # Match push_pending_message: SPUBLISH via execute_command so it works on
    # both Redis and AsyncRedisCluster (the cluster client has no publish()).
    try:
        await redis.execute_command(
            "SPUBLISH", _notify_channel(session_id), _NOTIFY_PAYLOAD
        )
    except Exception as e:  # pragma: no cover
        logger.warning("pending_messages: publish failed for %s: %s", session_id, e)

    logger.info(
        "pending_messages: pushed message to running session=%s (buffer_len=%d)",
        session_id,
        new_length,
    )
    return new_length


async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
    """Atomically pop all pending messages for *session_id*.

@@ -171,13 +185,8 @@ async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
    redis = await get_redis_async()
    key = _buffer_key(session_id)

    # Redis LPOP with count (Redis 6.2+) returns None for missing key,
    # empty list if we somehow race an empty key, or the popped items.
    # Draining MAX_PENDING_MESSAGES at once is safe because the push side
    # uses RPUSH + LTRIM(-MAX_PENDING_MESSAGES, -1) to cap the list to that
    # same value, so the list can never hold more items than we drain here.
    # If the cap is raised on the push side, raise the drain count here too
    # (or switch to a loop drain).
    # LPOP with count drains everything in one round-trip; the push side
    # caps the list at MAX_PENDING_MESSAGES so nothing is left behind.
    lpop_result = await redis.lpop(key, MAX_PENDING_MESSAGES)  # type: ignore[assignment]
    if not lpop_result:
        return []
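The push/drain invariant (RPUSH + LTRIM cap on the push side, `LPOP count` on the drain side) can be simulated with a plain list; the names here are illustrative stand-ins, not the real Redis helpers:

```python
MAX_PENDING = 3  # stand-in for MAX_PENDING_MESSAGES


def capped_rpush(buf: list[str], item: str, max_len: int = MAX_PENDING) -> int:
    """RPUSH + LTRIM(-max_len, -1): append, then keep only the newest max_len."""
    buf.append(item)
    del buf[:-max_len]  # trim from the left (oldest): the newest always wins
    return len(buf)


def drain(buf: list[str], count: int = MAX_PENDING) -> list[str]:
    """LPOP with count: pop up to `count` items from the head in one shot."""
    popped, buf[:count] = buf[:count], []
    return popped


buf: list[str] = []
for msg in ["m1", "m2", "m3", "m4"]:
    capped_rpush(buf, msg)
drained = drain(buf)
print(drained)  # → ['m2', 'm3', 'm4'] (m1 was trimmed by the cap)
print(buf)      # → [] (cap == drain count, so nothing is left behind)
```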
@@ -241,24 +250,17 @@ async def peek_pending_messages(session_id: str) -> list[PendingMessage]:


async def clear_pending_messages_unsafe(session_id: str) -> None:
    """Drop the session's pending buffer — **not** the normal turn cleanup.
    """Drop the session's pending buffer — operator/debug escape hatch.

    The ``_unsafe`` suffix warns: reaching for this at turn end drops queued
    follow-ups on the floor instead of running them (the bug fixed by commit
    b64be73). The atomic ``LPOP`` drain at turn start is the primary consumer;
    anything pushed after the drain window belongs to the next turn by
    definition. Retained only as an operator/debug escape hatch for manually
    clearing a stuck session and as a fixture in the unit tests.
    The ``_unsafe`` suffix warns: normal turn cleanup uses the atomic LPOP
    drain; this bypass drops queued follow-ups on the floor.
    """
    redis = await get_redis_async()
    await redis.delete(_buffer_key(session_id))


# Per-message and total-block caps for inline tool-boundary injection.
# Per-message keeps a single long paste from dominating; the total cap
# keeps the follow-up block small relative to the 100 KB MCP truncation
# boundary so tool output always stays the larger share of the wrapper
# return value.
# Per-message + total caps keep the follow-up block bounded relative to the
# 100 KB MCP tool-output truncation boundary.
_FOLLOWUP_CONTENT_MAX_CHARS = 2_000
_FOLLOWUP_TOTAL_MAX_CHARS = 6_000

@@ -273,17 +275,9 @@ async def stash_pending_for_persist(
) -> None:
    """Enqueue drained PendingMessages for UI-row persistence.

    Writes each message as a JSON payload to
    ``copilot:pending-persist:{session_id}``. The SDK service's
    tool-result dispatch handler LPOPs this queue right after appending
    the tool_result row to ``session.messages``, so the resulting user
    row lands at the correct chronological position (after the tool
    output the follow-up was drained against).

    Fire-and-forget on Redis failures: a stash failure means Claude
    still saw the follow-up in tool output (the injection step ran
    first), so the only consequence is a missing UI bubble. Logged
    so it can be spotted.
    The SDK service LPOPs this right after appending the tool_result row so
    the user bubble lands after the tool output. Stash failures are logged
    but not raised — the only consequence is a missing UI bubble.
    """
    if not messages:
        return
@@ -336,8 +330,7 @@ async def drain_pending_for_persist(session_id: str) -> list[PendingMessage]:
            )
        except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
            logger.warning(
                "pending_messages: dropping malformed persist-queue entry "
                "for %s: %s",
                "pending_messages: dropping malformed persist-queue entry for %s: %s",
                session_id,
                e,
            )

@@ -60,6 +60,16 @@ class _FakeRedis:
        self.published.append((channel, payload))
        return 1

    async def execute_command(self, *args: Any) -> Any:
        # Minimal handler for the sharded SPUBLISH call made by
        # push_pending_message. Routing semantics are irrelevant here —
        # we just record the publish for assertion.
        if args and args[0] == "SPUBLISH":
            _, channel, payload = args
            self.published.append((channel, payload))
            return 1
        raise NotImplementedError(f"fake execute_command does not handle {args[0]!r}")

    async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
        lst = self.lists.get(key)
        if not lst:
@@ -326,7 +336,7 @@ async def test_drain_skips_malformed_entries(
    fake_redis: _FakeRedis,
) -> None:
    # Seed the fake with a mix of valid and malformed payloads
    fake_redis.lists["copilot:pending:bad"] = [
    fake_redis.lists[pm_module._buffer_key("bad")] = [
        json.dumps({"content": "valid"}),
        "{not valid json",
        json.dumps({"content": "also valid", "file_ids": ["a"]}),
@@ -347,7 +357,7 @@ async def test_drain_decodes_bytes_payloads(
    branch in ``drain_pending_messages`` so a regression there doesn't
    slip past CI.
    """
    fake_redis.lists["copilot:pending:bytes_sess"] = [
    fake_redis.lists[pm_module._buffer_key("bytes_sess")] = [
        json.dumps({"content": "from bytes"}).encode("utf-8"),
    ]
    drained = await drain_pending_messages("bytes_sess")
@@ -362,14 +372,14 @@ async def test_peek_decodes_bytes_payloads(
    """``peek_pending_messages`` uses the same ``_decode_redis_item`` helper
    as the drain path. Seed with bytes to guard against regression.
    """
    fake_redis.lists["copilot:pending:peek_bytes_sess"] = [
    fake_redis.lists[pm_module._buffer_key("peek_bytes_sess")] = [
        json.dumps({"content": "peeked from bytes"}).encode("utf-8"),
    ]
    peeked = await peek_pending_messages("peek_bytes_sess")
    assert len(peeked) == 1
    assert peeked[0].content == "peeked from bytes"
    # peek must NOT consume the item
    assert fake_redis.lists["copilot:pending:peek_bytes_sess"] != []
    assert fake_redis.lists[pm_module._buffer_key("peek_bytes_sess")] != []


# ── Concurrency ─────────────────────────────────────────────────────

@@ -445,7 +455,7 @@ async def test_peek_pending_messages_decodes_bytes_payloads(
    fake_redis: _FakeRedis,
) -> None:
    """peek_pending_messages decodes bytes entries the same way drain does."""
    fake_redis.lists["copilot:pending:peek_bytes"] = [
    fake_redis.lists[pm_module._buffer_key("peek_bytes")] = [
        json.dumps({"content": "from bytes"}).encode("utf-8"),
    ]
    peeked = await peek_pending_messages("peek_bytes")
@@ -458,7 +468,7 @@ async def test_peek_pending_messages_skips_malformed_entries(
    fake_redis: _FakeRedis,
) -> None:
    """Malformed entries are skipped and valid ones are returned."""
    fake_redis.lists["copilot:pending:peek_bad"] = [
    fake_redis.lists[pm_module._buffer_key("peek_bad")] = [
        json.dumps({"content": "valid peek"}),
        "{bad json",
        json.dumps({"content": "also valid peek"}),
@@ -486,7 +496,7 @@ async def test_stash_for_persist_enqueues_and_drain_pops_in_order(

    # Stored under the distinct persist key, NOT the primary buffer.
    assert "copilot:pending-persist:sess-persist" in fake_redis.lists
    assert "copilot:pending:sess-persist" not in fake_redis.lists
    assert pm_module._buffer_key("sess-persist") not in fake_redis.lists

    drained = await drain_pending_for_persist("sess-persist")
    assert len(drained) == 2
@@ -612,3 +622,59 @@ async def test_drain_and_format_for_injection_swallows_redis_error(
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_missing_session_id() -> None:
    assert await drain_and_format_for_injection("", log_prefix="[TEST]") == ""


# ── Cluster-slot colocation regression ───────────────────────────────
# The gated-rpush Lua script in `capped_rpush_if_hash_field` touches both
# the session-meta hash (`stream_registry._get_session_meta_key`) and the
# pending buffer list (`_buffer_key`) atomically. Redis Cluster requires
# every key referenced by a multi-key Lua script to hash to the same slot,
# so both keys must share a hash tag (the `{...}` substring Redis uses for
# slot calculation). Without this, the EVAL returns `CROSSSLOT keys in
# request` once cluster mode is active.


def _redis_keyslot(key: str) -> int:
    """Compute the Redis Cluster slot for ``key`` using CRC16-XMODEM mod 16384.

    Mirrors the algorithm in redis-py's ``RedisCluster.keyslot`` and the
    Redis spec — extracts the first ``{...}`` substring as the hash tag,
    falls back to the whole key when no tag is present.
    """
    start = key.find("{")
    if start != -1:
        end = key.find("}", start + 1)
        if end > start + 1:
            key = key[start + 1 : end]
    crc = 0
    poly = 0x1021
    for byte in key.encode():
        crc ^= byte << 8
        for _ in range(8):
            crc = ((crc << 1) ^ poly) & 0xFFFF if crc & 0x8000 else (crc << 1) & 0xFFFF
    return crc % 16384


def test_buffer_and_session_meta_keys_share_cluster_slot() -> None:
    """Regression: pending-buffer key + session-meta key must hash to the
    same Redis Cluster slot, otherwise the gated-rpush Lua script returns
    CROSSSLOT once cluster mode is enabled."""
    # Late import so the test doesn't pull stream_registry's heavy module
    # graph (it transitively wires the AppService client) at file load.
    from backend.copilot.stream_registry import _get_session_meta_key

    for session_id in [
        "sess-abcdef-123",
        "0eb0aae8-6926-4b50-97af-72840841dc70",
        "x",
    ]:
        buf = pm_module._buffer_key(session_id)
        meta = _get_session_meta_key(session_id)
        assert "{" in buf and "}" in buf, f"_buffer_key missing hash tag: {buf!r}"
        assert (
            "{" in meta and "}" in meta
        ), f"_get_session_meta_key missing hash tag: {meta!r}"
        assert _redis_keyslot(buf) == _redis_keyslot(meta), (
            f"CROSSSLOT regression: {buf!r} (slot {_redis_keyslot(buf)}) "
            f"!= {meta!r} (slot {_redis_keyslot(meta)})"
        )

@@ -128,6 +128,14 @@ ToolName = Literal[
# Frozen set of all valid tool names — derived from the Literal.
ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))

DISABLED_LEGACY_TOOL_NAMES: frozenset[str] = frozenset({"ask_question"})
"""Tool names accepted only for backwards compatibility with saved graphs.

These names are intentionally absent from ``ToolName`` and
``PLATFORM_TOOL_NAMES`` so they are not exposed in new block schemas or sent to
the model as available tools.
"""

# SDK built-in tool names — tools provided by the Claude Code CLI that our
# code does not implement directly. ``TodoWrite`` is DELIBERATELY excluded:
# baseline mode ships an MCP-wrapped platform version
@@ -304,7 +312,11 @@ def validate_tool_names(tools: list[str]) -> list[str]:
    Returns:
        List of invalid names (empty if all are valid).
    """
    return [t for t in tools if t not in ALL_TOOL_NAMES]
    return [
        t
        for t in tools
        if t not in ALL_TOOL_NAMES and t not in DISABLED_LEGACY_TOOL_NAMES
    ]


_tool_names_checked = False

@@ -257,6 +257,9 @@ class TestValidateToolNames:
    def test_valid_sdk_builtin(self):
        assert validate_tool_names(["Read", "Task", "WebSearch"]) == []

    def test_disabled_legacy_tool_name_is_accepted(self):
        assert validate_tool_names(["ask_question"]) == []

    def test_invalid_tool(self):
        result = validate_tool_names(["nonexistent_tool"])
        assert "nonexistent_tool" in result

@@ -47,15 +47,14 @@ from pydantic import BaseModel, Field
from redis.exceptions import RedisError

from backend.data.db_accessors import user_db
from backend.data.redis_client import get_redis_async
from backend.data.redis_client import AsyncRedisClient, get_redis_async
from backend.data.user import get_user_by_id
from backend.util.cache import cached

logger = logging.getLogger(__name__)

# Redis key prefixes. Bumped from "copilot:usage" (token-based) to
# "copilot:cost" on the token→cost migration so stale counters do not
# get misinterpreted as microdollars (which would dramatically under-count).
# "copilot:cost" (not the legacy "copilot:usage") so stale token-based
# counters are not misread as microdollars.
_USAGE_KEY_PREFIX = "copilot:cost"


@@ -73,6 +72,7 @@ class SubscriptionTier(str, Enum):
        from prisma.enums import SubscriptionTier
    """

    NO_TIER = "NO_TIER"
    BASIC = "BASIC"
    PRO = "PRO"
    MAX = "MAX"
@@ -88,6 +88,14 @@ class SubscriptionTier(str, Enum):
# eventual ``int(base * multiplier)`` in ``get_global_rate_limits`` keeps the
# downstream microdollar math integer.
_DEFAULT_TIER_MULTIPLIERS: dict[SubscriptionTier, float] = {
    # NO_TIER is the explicit "no active Stripe subscription" state —
    # multiplier 0.0 collapses the per-period limit to int(base * 0) = 0, so
    # all rate-limited routes (CoPilot chat, AutoPilot) refuse with 429
    # before any business logic runs. This is the backend half of the
    # paywall (the frontend modal nudges UI users; this gate enforces
    # server-side regardless of client). BASIC stays as a future paid-tier
    # option; for now it falls back to the same baseline as paid tiers.
    SubscriptionTier.NO_TIER: 0.0,
    SubscriptionTier.BASIC: 1.0,
    SubscriptionTier.PRO: 5.0,
    SubscriptionTier.MAX: 20.0,
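To make the `int(base * multiplier)` gate concrete, here is a sketch with an assumed base budget (the dollar figure is invented for illustration, not the platform's real limit):

```python
# Hypothetical per-period base budget in microdollars (illustrative only).
BASE_LIMIT = 10_000_000  # $10.00

multipliers = {"NO_TIER": 0.0, "BASIC": 1.0, "PRO": 5.0, "MAX": 20.0}

# int(base * multiplier) keeps the downstream microdollar math integer.
limits = {tier: int(BASE_LIMIT * m) for tier, m in multipliers.items()}
print(limits)
# NO_TIER collapses to 0, so rate-limited routes refuse with 429 up front.
```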
@@ -100,7 +108,7 @@ _DEFAULT_TIER_MULTIPLIERS: dict[SubscriptionTier, float] = {
# ``get_tier_multipliers`` so LD overrides are honoured.
TIER_MULTIPLIERS = _DEFAULT_TIER_MULTIPLIERS

DEFAULT_TIER = SubscriptionTier.BASIC
DEFAULT_TIER = SubscriptionTier.NO_TIER


@cached(ttl_seconds=60, maxsize=1, cache_none=False)
@@ -447,25 +455,16 @@ async def reset_daily_usage(user_id: str, daily_cost_limit: int = 0) -> bool:
    try:
        redis = await get_redis_async()

        # Use a MULTI/EXEC transaction so that DELETE (daily) and DECRBY
        # (weekly) either both execute or neither does. This prevents the
        # scenario where the daily counter is cleared but the weekly
        # counter is not decremented — which would let the caller refund
        # credits even though the daily limit was already reset.
        d_key = _daily_key(user_id, now=now)
        w_key = _weekly_key(user_id, now=now) if daily_cost_limit > 0 else None

        pipe = redis.pipeline(transaction=True)
        pipe.delete(d_key)
        # Daily and weekly keys hash to different cluster slots, so cross-key
        # MULTI/EXEC is not available. Issue the writes sequentially — the
        # failure mode (daily deleted, weekly not decremented) is a
        # best-effort refund budget that the read path already tolerates.
        await redis.delete(d_key)
        if w_key is not None:
            pipe.decrby(w_key, daily_cost_limit)
        results = await pipe.execute()

        # Clamp negative weekly counter to 0 (best-effort; not critical).
        if w_key is not None:
            new_val = results[1]  # DECRBY result
            if new_val < 0:
                await redis.set(w_key, 0, keepttl=True)
            await _decr_counter_floor_zero(redis, w_key, daily_cost_limit)

        logger.info("Reset daily usage for user %s", user_id[:8])
        return True
@@ -555,30 +554,18 @@ async def record_cost_usage(
    logger.info("Recording copilot spend: %d microdollars", cost_microdollars)

    now = datetime.now(UTC)
    d_key = _daily_key(user_id, now=now)
    w_key = _weekly_key(user_id, now=now)
    daily_ttl = max(int((_daily_reset_time(now=now) - now).total_seconds()), 1)
    weekly_ttl = max(int((_weekly_reset_time(now=now) - now).total_seconds()), 1)
    try:
        redis = await get_redis_async()
        # Use MULTI/EXEC so each INCRBY/EXPIRE pair is atomic — guarantees
        # the TTL is set even if the connection drops mid-pipeline, so
        # counters can never survive past their date-based rotation window.
        pipe = redis.pipeline(transaction=True)

        # Daily counter (expires at next midnight UTC)
        d_key = _daily_key(user_id, now=now)
        pipe.incrby(d_key, cost_microdollars)
        seconds_until_daily_reset = int(
            (_daily_reset_time(now=now) - now).total_seconds()
        )
        pipe.expire(d_key, max(seconds_until_daily_reset, 1))

        # Weekly counter (expires end of week)
        w_key = _weekly_key(user_id, now=now)
        pipe.incrby(w_key, cost_microdollars)
        seconds_until_weekly_reset = int(
            (_weekly_reset_time(now=now) - now).total_seconds()
        )
        pipe.expire(w_key, max(seconds_until_weekly_reset, 1))

        await pipe.execute()
        # Daily and weekly keys hash to different cluster slots — cross-slot
        # MULTI/EXEC is not supported, so each counter gets its own
        # single-key transaction. Per-counter INCRBY+EXPIRE atomicity is the
        # invariant that matters; the two counters are independent budgets.
        await _incr_counter_atomic(redis, d_key, cost_microdollars, daily_ttl)
        await _incr_counter_atomic(redis, w_key, cost_microdollars, weekly_ttl)
    except (RedisError, ConnectionError, OSError):
        logger.warning(
            "Redis unavailable for recording cost usage (microdollars=%d)",
@@ -586,30 +573,56 @@ async def record_cost_usage(
        )


async def _incr_counter_atomic(
    redis: AsyncRedisClient, key: str, delta: int, ttl_seconds: int
) -> None:
    """INCRBY + EXPIRE on a single key inside a MULTI/EXEC transaction."""
    pipe = redis.pipeline(transaction=True)
    pipe.incrby(key, delta)
    pipe.expire(key, ttl_seconds)
    await pipe.execute()


# Atomic DECRBY + floor-to-zero so a concurrent INCRBY from record_cost_usage
# cannot be lost. DELETE on underflow also avoids leaving a zero-valued key
# with no TTL, which the non-atomic set-with-keepttl variant could do.
_DECR_FLOOR_ZERO_SCRIPT = """
local value = redis.call("DECRBY", KEYS[1], ARGV[1])
if value < 0 then
    redis.call("DEL", KEYS[1])
    return 0
end
return value
"""


async def _decr_counter_floor_zero(
    redis: AsyncRedisClient, key: str, delta: int
) -> None:
    """Atomically DECRBY ``delta`` on ``key`` and DEL on underflow.

    DEL on underflow avoids leaving a zero-valued key without a TTL, so the
    next INCRBY in ``record_cost_usage`` re-seeds both the value and the
    expiry in one shot.
    """
    await redis.eval(_DECR_FLOOR_ZERO_SCRIPT, 1, key, delta)
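The script's semantics can be modelled in pure Python against a dict standing in for Redis (the key name and helper are illustrative; the real atomicity guarantee comes from Redis executing the Lua script as a single unit):

```python
def decr_floor_zero(store: dict[str, int], key: str, delta: int) -> int:
    """Mirror the DECRBY-then-DEL-on-underflow script against a dict.

    Deleting (instead of writing 0) means the next INCRBY re-creates the
    key, so the INCRBY+EXPIRE pair re-seeds both value and TTL together.
    """
    value = store.get(key, 0) - delta  # DECRBY treats a missing key as 0
    if value < 0:
        store.pop(key, None)  # DEL: no zero-valued key left without a TTL
        return 0
    store[key] = value
    return value


k = "copilot:cost:weekly:u1"  # illustrative key name
store = {k: 500}
r1 = decr_floor_zero(store, k, 200)
r2 = decr_floor_zero(store, k, 900)
print(r1, r2, k in store)  # → 300 0 False
```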
class _UserNotFoundError(Exception):
    """Raised when a user record is missing or has no subscription tier.

    Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
    by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
    decorator from storing the fallback value. This avoids a race condition
    where a non-existent user's DEFAULT_TIER is cached, then the user is
    created with a higher tier but receives the stale cached FREE tier for
    up to 5 minutes.
    Raising (rather than returning ``DEFAULT_TIER``) prevents ``@cached``
    from persisting the fallback, which would otherwise keep serving FREE
    for up to the TTL after the user's real tier is set.
    """


@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
    """Fetch the user's rate-limit tier from the database (cached via Redis).
    """Fetch the user's rate-limit tier, cached across pods.

    Uses ``shared_cache=True`` so that tier changes propagate across all pods
    immediately when the cache entry is invalidated (via ``cache_delete``).

    Only successful DB lookups of existing users with a valid tier are cached.
    Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
    the ``@cached`` decorator does **not** store a fallback value. This
    prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
    cached and then persists after the user is created with a higher tier.
    Only successful lookups are cached. Missing users raise
    ``_UserNotFoundError`` so ``@cached`` never stores the fallback.
    """
    try:
        user = await user_db().get_user_by_id(user_id)
@@ -651,20 +664,10 @@ get_user_tier.cache_delete = _fetch_user_tier.cache_delete  # type: ignore[attr-
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
    """Persist the user's rate-limit tier to the database.

    Invalidates every cache that keys off the user's subscription tier so the
    change is visible immediately: this function's own ``get_user_tier``, the
    shared ``get_user_by_id`` (which exposes ``user.subscription_tier``), and
    ``get_pending_subscription_change`` (since an admin override can invalidate
    a cached ``cancel_at_period_end`` or schedule-based pending state).

    If the user has an active Stripe subscription whose current price does not
    match ``tier``, Stripe will keep billing the old price and the next
    ``customer.subscription.updated`` webhook will overwrite the DB tier back
    to whatever Stripe has. Proper reconciliation (cancelling or modifying the
    Stripe subscription when an admin overrides the tier) is out of scope for
    this PR — it changes the admin contract and needs its own test coverage.
    For now we emit a ``WARNING`` so drift surfaces via Sentry until that
    follow-up lands.
    Invalidates the caches that expose ``subscription_tier`` so the change
    takes effect immediately. If the user has an active Stripe subscription
    on a mismatched price, emits a WARNING; Stripe remains the billing
    source of truth and the next webhook will reconcile the DB tier.

    Raises:
        prisma.errors.RecordNotFoundError: If the user does not exist.
@@ -674,21 +677,13 @@ async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
        data={"subscriptionTier": tier.value},
    )
    get_user_tier.cache_delete(user_id)  # type: ignore[attr-defined]
    # Local import required: backend.data.credit imports backend.copilot.rate_limit
    # (via get_user_tier in credit.py's _invalidate_user_tier_caches), so a
    # top-level ``from backend.data.credit import ...`` here would create a
    # circular import at module-load time.
    # Local import: backend.data.credit imports from this module.
    from backend.data.credit import get_pending_subscription_change

    get_user_by_id.cache_delete(user_id)  # type: ignore[attr-defined]
    get_pending_subscription_change.cache_delete(user_id)  # type: ignore[attr-defined]

    # The DB write above is already committed; the drift check is best-effort
    # diagnostic logging. Fire-and-forget so admin bulk ops don't wait on a
    # Stripe roundtrip. The inner helper wraps its body in a timeout + broad
    # except so background task errors still surface via logs rather than as
    # "task exception never retrieved" warnings. Cancellation on request
    # shutdown is acceptable — the drift warning is non-load-bearing.
    # Fire-and-forget drift check so admin bulk ops don't wait on Stripe.
    asyncio.ensure_future(_drift_check_background(user_id, tier))


@@ -711,8 +706,6 @@ async def _drift_check_background(user_id: str, tier: SubscriptionTier) -> None:
            tier.value,
        )
    except asyncio.CancelledError:
        # Request may have completed and the event loop is cancelling tasks —
        # the drift log is non-critical, so accept cancellation silently.
        raise
    except Exception:
        logger.exception(
@@ -726,19 +719,9 @@ async def _drift_check_background(user_id: str, tier: SubscriptionTier) -> None:
async def _warn_if_stripe_subscription_drifts(
    user_id: str, new_tier: SubscriptionTier
) -> None:
    """Emit a WARNING when an admin tier override leaves an active Stripe sub on a
    mismatched price.

    The warning is diagnostic only: Stripe remains the billing source of truth,
    so the next ``customer.subscription.updated`` webhook will reset the DB
    tier. Surfacing the drift here lets ops catch admin overrides that bypass
    the intended Checkout / Portal cancel flows before users notice surprise
    charges.
    """
    # Local imports: see note in ``set_user_tier`` about the credit <-> rate_limit
    # circular. These helpers (``_get_active_subscription``,
    # ``get_subscription_price_id``) live in credit.py alongside the rest of
    # the Stripe billing code.
    """Emit a WARNING when an admin tier override leaves an active Stripe
    subscription on a mismatched price."""
    # Local import: breaks a credit <-> rate_limit circular at module load.
    from backend.data.credit import _get_active_subscription, get_subscription_price_id

    try:
@@ -753,10 +736,8 @@ async def _warn_if_stripe_subscription_drifts(
            return
        price = items[0].price
        current_price_id = price if isinstance(price, str) else price.id
        # The LaunchDarkly-backed price lookup must live inside this try/except:
        # an LD SDK failure (network, token revoked) here would otherwise
        # propagate past set_user_tier's already-committed DB write and turn a
        # best-effort diagnostic into a 500 on admin tier writes.
        # Inside the try/except: an LD SDK failure here must not turn a
        # best-effort diagnostic into a 500 after the DB write committed.
        expected_price_id = await get_subscription_price_id(new_tier)
    except Exception:
        logger.debug(
@@ -816,6 +797,16 @@ async def get_global_rate_limits(
    tier = await get_user_tier(user_id)
    multipliers = await get_tier_multipliers()
    multiplier = multipliers.get(tier.value, 1.0)
    # NO_TIER's 0.0 multiplier is the backend half of the paywall — it
    # collapses limits to zero so unsubscribed users can't run the chat.
    # Only enforce that gate when the platform-payment flag is on for this
    # user; in the beta cohort (flag off) NO_TIER falls back to BASIC's
    # baseline so the e2e suite and beta testers retain access.
|
||||
if tier == SubscriptionTier.NO_TIER:
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
|
||||
if not await is_feature_enabled(Flag.ENABLE_PLATFORM_PAYMENT, user_id):
|
||||
multiplier = multipliers.get(SubscriptionTier.BASIC.value, 1.0)
|
||||
if multiplier != 1.0:
|
||||
# Cast back to int to preserve the microdollar integer contract
|
||||
# downstream — fractional LD multipliers (e.g. 8.5×) truncate at the
|
||||
@@ -838,12 +829,15 @@ async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:
|
||||
the admin believing the counters were zeroed when they were not.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
keys_to_delete = [_daily_key(user_id, now=now)]
|
||||
if reset_weekly:
|
||||
keys_to_delete.append(_weekly_key(user_id, now=now))
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
w_key = _weekly_key(user_id, now=now) if reset_weekly else None
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(*keys_to_delete)
|
||||
# Daily and weekly keys hash to different cluster slots — multi-key
|
||||
# DELETE would raise CROSSSLOT, so issue separate calls.
|
||||
await redis.delete(d_key)
|
||||
if w_key is not None:
|
||||
await redis.delete(w_key)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for resetting user usage")
|
||||
raise
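For context on the CROSSSLOT comment above: Redis Cluster only allows a multi-key command when every key hashes to the same slot, and the standard way to force co-location is a `{hash tag}` — only the tag's interior is hashed. A small sketch of the tag-extraction rule; the key formats below are made up for illustration (the real `_daily_key`/`_weekly_key` shapes are not shown in this diff):

```python
def hash_tag_portion(key: str) -> str:
    # Redis Cluster slotting rule: if the key contains a "{...}" section
    # with a non-empty interior, only that interior feeds CRC16 mod 16384;
    # otherwise the whole key is hashed. Keys sharing a tag are guaranteed
    # to land in the same slot, which would make multi-key DELETE legal.
    start = key.find("{")
    if start != -1:
        end = key.find("}", start + 1)
        if end > start + 1:  # "{}" (empty tag) hashes the whole key
            return key[start + 1 : end]
    return key


# Illustrative key shapes only:
daily = "copilot:usage:daily:{user-1}:2024-01-01"
weekly = "copilot:usage:weekly:{user-1}:2024-W01"
same_slot = hash_tag_portion(daily) == hash_tag_portion(weekly)
```

The code above keeps separate DELETE calls instead, which avoids renaming existing keys — a reasonable trade-off since the two calls need not be atomic here.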

@@ -359,6 +359,9 @@ class TestSubscriptionTier:
def test_tier_multipliers(self):
# Float-typed so LD-provided fractional multipliers compose naturally;
# equality against int literals still holds for the whole defaults.
# NO_TIER is 0.0 — explicit "no active subscription" state;
# rate-limited routes refuse with 429 (backend half of the paywall).
assert TIER_MULTIPLIERS[SubscriptionTier.NO_TIER] == 0.0
assert TIER_MULTIPLIERS[SubscriptionTier.BASIC] == 1.0
assert TIER_MULTIPLIERS[SubscriptionTier.PRO] == 5.0
assert TIER_MULTIPLIERS[SubscriptionTier.MAX] == 20.0
@@ -366,8 +369,8 @@ class TestSubscriptionTier:
assert TIER_MULTIPLIERS[SubscriptionTier.ENTERPRISE] == 60.0
assert TIER_MULTIPLIERS is _DEFAULT_TIER_MULTIPLIERS

def test_default_tier_is_basic(self):
assert DEFAULT_TIER == SubscriptionTier.BASIC
def test_default_tier_is_no_tier(self):
assert DEFAULT_TIER == SubscriptionTier.NO_TIER

def test_usage_status_includes_tier(self):
now = datetime.now(UTC)
@@ -375,7 +378,7 @@ class TestSubscriptionTier:
daily=UsageWindow(used=0, limit=100, resets_at=now + timedelta(hours=1)),
weekly=UsageWindow(used=0, limit=500, resets_at=now + timedelta(days=1)),
)
assert status.tier == SubscriptionTier.BASIC
assert status.tier == SubscriptionTier.NO_TIER

def test_usage_status_with_custom_tier(self):
now = datetime.now(UTC)
@@ -1243,18 +1246,9 @@ class TestTierLimitsRespected:


class TestResetDailyUsage:
@staticmethod
def _make_pipeline_mock(decrby_result: int = 0) -> MagicMock:
"""Create a pipeline mock that returns [delete_result, decrby_result]."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[1, decrby_result])
return pipe

@pytest.mark.asyncio
async def test_deletes_daily_key(self):
mock_pipe = self._make_pipeline_mock(decrby_result=0)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe

with patch(
"backend.copilot.rate_limit.get_redis_async",
@@ -1263,14 +1257,12 @@ class TestResetDailyUsage:
result = await reset_daily_usage(_USER, daily_cost_limit=10000)

assert result is True
mock_pipe.delete.assert_called_once()
mock_redis.delete.assert_called_once()

@pytest.mark.asyncio
async def test_reduces_weekly_usage_via_decrby(self):
"""Weekly counter should be reduced via DECRBY in the pipeline."""
mock_pipe = self._make_pipeline_mock(decrby_result=35000)
async def test_reduces_weekly_usage_via_eval(self):
"""Weekly counter should be decremented via the atomic Lua script."""
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe

with patch(
"backend.copilot.rate_limit.get_redis_async",
@@ -1278,32 +1270,22 @@ class TestResetDailyUsage:
):
await reset_daily_usage(_USER, daily_cost_limit=10000)

mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed

@pytest.mark.asyncio
async def test_clamps_negative_weekly_to_zero(self):
"""If DECRBY goes negative, SET to 0 (outside the pipeline)."""
mock_pipe = self._make_pipeline_mock(decrby_result=-5000)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe

with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_cost_limit=10000)

mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_called_once()
# The Lua script handles both decrement and floor-to-zero in a single
# call — no separate SET is expected for the clamp branch any more.
# Pin the call shape so a regression that targets the wrong key or
# delta (e.g. the daily key, or a sign-flip) fails loudly.
mock_redis.eval.assert_called_once()
eval_args = mock_redis.eval.call_args.args
# eval(script, numkeys, KEYS[1], ARGV[1])
assert eval_args[1] == 1
assert eval_args[2] == _weekly_key(_USER)
assert int(eval_args[3]) == 10000
mock_redis.set.assert_not_called()

@pytest.mark.asyncio
async def test_no_weekly_reduction_when_daily_limit_zero(self):
"""When daily_cost_limit is 0, weekly counter should not be touched."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe

with patch(
"backend.copilot.rate_limit.get_redis_async",
@@ -1311,8 +1293,8 @@ class TestResetDailyUsage:
):
await reset_daily_usage(_USER, daily_cost_limit=0)

mock_pipe.delete.assert_called_once()
mock_pipe.decrby.assert_not_called()
mock_redis.delete.assert_called_once()
mock_redis.eval.assert_not_called()

@pytest.mark.asyncio
async def test_returns_false_when_redis_unavailable(self):
@@ -1324,6 +1306,23 @@ class TestResetDailyUsage:

assert result is False

@pytest.mark.asyncio
async def test_decr_counter_floor_zero_invokes_lua_script(self):
"""The atomic DECRBY+floor helper routes through redis.eval with the
expected single-key, single-arg call shape."""
from backend.copilot.rate_limit import (
_DECR_FLOOR_ZERO_SCRIPT,
_decr_counter_floor_zero,
)

mock_redis = AsyncMock()

await _decr_counter_floor_zero(mock_redis, "weekly:user1", 42)

mock_redis.eval.assert_called_once_with(
_DECR_FLOOR_ZERO_SCRIPT, 1, "weekly:user1", 42
)
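The `_DECR_FLOOR_ZERO_SCRIPT` body itself is not part of this diff. The following pure-Python model captures the semantics the tests above pin down — a single round trip that decrements by the given delta and clamps a negative result to zero, with no follow-up SET from the client:

```python
def decr_floor_zero(store: dict[str, int], key: str, delta: int) -> int:
    # Reference model of the Lua script's contract: DECRBY `delta`,
    # then floor the counter at zero, atomically (here the dict stands
    # in for Redis, so atomicity is trivial).
    value = store.get(key, 0) - delta
    if value < 0:
        value = 0
    store[key] = value
    return value


store = {"weekly:user1": 35000}
after_decr = decr_floor_zero(store, "weekly:user1", 10000)   # normal path
after_clamp = decr_floor_zero(store, "weekly:user1", 99999)  # clamp path
```

This is why the test asserts `mock_redis.set.assert_not_called()` in both branches: the clamp happens server-side inside the script.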


# ---------------------------------------------------------------------------
# Tier-limit enforcement (integration-style)
@@ -1781,8 +1780,9 @@ class TestResetUserUsage:
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
await reset_user_usage("user-1", reset_weekly=True)
args = mock_redis.delete.call_args[0]
assert len(args) == 2 # both daily and weekly keys
# Daily and weekly keys hash to different cluster slots, so they are
# deleted via two separate DELETE calls (not a single multi-key one).
assert mock_redis.delete.call_count == 2

@pytest.mark.asyncio
async def test_raises_on_redis_failure(self):

@@ -52,7 +52,8 @@ class ResponseType(str, Enum):
ERROR = "error"
USAGE = "usage"
HEARTBEAT = "heartbeat"
STATUS = "status"
STATUS = "data-status"
CURSOR = "data-cursor"


class StreamBaseResponse(BaseModel):
@@ -275,10 +276,18 @@ class StreamError(StreamBaseResponse):

The AI SDK uses z.strictObject({type, errorText}) which rejects
any extra fields like `code` or `details`.

When ``code`` is set we prefix ``errorText`` with ``[code:<id>]`` so
the frontend can still parse a machine-readable code out of the
otherwise opaque text. Idempotent: if the caller already embedded
the prefix, we don't double it.
"""
text = self.errorText
if self.code and not text.lstrip().startswith(f"[code:{self.code}]"):
text = f"[code:{self.code}] {text}"
data = {
"type": self.type.value,
"errorText": self.errorText,
"errorText": text,
}
return f"data: {json_dumps(data)}\n\n"

@@ -300,17 +309,46 @@ class StreamHeartbeat(StreamBaseResponse):
return ": heartbeat\n\n"


class StreamCursor(StreamBaseResponse):
"""Deprecated Redis-stream cursor data part.

Kept so older stored chunks or tests can still be reconstructed, but new
stream subscriptions no longer emit it. AI SDK resume needs a full replay
from ``0-0`` so every ``*-delta`` has its matching ``*-start`` event.
"""

type: ResponseType = ResponseType.CURSOR
chunkId: str = Field(..., description="Redis Stream message ID (XADD)")

def to_sse(self) -> str:
"""Emit as an AI SDK v5 data part."""
data = {
"type": self.type.value,
"data": {"chunkId": self.chunkId},
}
return f"data: {json.dumps(data)}\n\n"


class StreamStatus(StreamBaseResponse):
"""Transient status notification shown to the user during long operations.

Used to provide feedback when the backend performs behind-the-scenes work
(e.g., compacting conversation context on a retry) that would otherwise
leave the user staring at an unexplained pause.

Sent as a proper ``data:`` event so the frontend can display it to the
user. The AI SDK stream parser gracefully skips unknown chunk types
(logs a console warning), so this does not break the stream.
Emitted when the backend is about to enter a phase that would otherwise
leave the user staring at a silent "Thinking…" bubble — e.g. the first
LLM call, the continuation after a tool result, compacting conversation
context on retry, or activating a fallback model. The frontend reads
the latest `data-status` part on the current assistant message and uses
its `message` in place of the generic "Thinking…" copy.
"""

type: ResponseType = ResponseType.STATUS
message: str = Field(..., description="Human-readable status message")

def to_sse(self) -> str:
"""Emit as an AI SDK v5 data part so the client surfaces it as
`type="data-status"` on `message.parts` instead of dropping it as
an unknown chunk type."""
data = {
"type": self.type.value,
"data": {"message": self.message},
}
return f"data: {json.dumps(data)}\n\n"

@@ -11,6 +11,33 @@ import pytest_asyncio

from backend.util import json

# ---------------------------------------------------------------------------
# Env vars that ``ChatConfig`` validators read — must be cleared so explicit
# constructor values are used. Centralised here so adding a new env-backed
# field only needs one update across the SDK test suite.
# ---------------------------------------------------------------------------
_CONFIG_ENV_VARS = (
"CHAT_USE_OPENROUTER",
"CHAT_API_KEY",
"OPEN_ROUTER_API_KEY",
"OPENAI_API_KEY",
"CHAT_BASE_URL",
"OPENROUTER_BASE_URL",
"OPENAI_BASE_URL",
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
"CHAT_USE_CLAUDE_AGENT_SDK",
"CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE",
"CHAT_CLAUDE_AGENT_CLI_PATH",
"CLAUDE_AGENT_CLI_PATH",
)


@pytest.fixture()
def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
"""Clear env-backed CHAT_* settings so ChatConfig uses constructor values."""
for var in _CONFIG_ENV_VARS:
monkeypatch.delenv(var, raising=False)


@pytest_asyncio.fixture(scope="session", loop_scope="session", name="server")
async def _server_noop() -> None:

@@ -22,6 +22,7 @@ from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable

from .service import (
_RETRYABLE_STREAM_ERROR_CODES,
_classify_final_failure,
_FinalFailure,
_flush_orphan_tool_uses_to_session,
@@ -320,3 +321,22 @@ class TestRetryRollbackContract:
"part-2",
f"{COPILOT_ERROR_PREFIX} Boom",
]


class TestRetryableStreamErrorCodes:
"""SECRT-2252: ``_dispatch_response`` consults this set to decide whether
the StreamError flowing through it should append a retryable marker (UI
shows a retry button) or a terminal one (UI shows ErrorCard only)."""

def test_transient_api_error_is_retryable(self):
assert "transient_api_error" in _RETRYABLE_STREAM_ERROR_CODES

def test_empty_completion_is_retryable(self):
# The adapter emits this for ghost-finished SDK turns. The user
# message ("The model returned an empty response.") only makes sense
# if the UI offers a retry — otherwise the user sees a dead error.
assert "empty_completion" in _RETRYABLE_STREAM_ERROR_CODES

def test_unknown_codes_are_not_retryable(self):
assert "sdk_error" not in _RETRYABLE_STREAM_ERROR_CODES
assert "all_attempts_exhausted" not in _RETRYABLE_STREAM_ERROR_CODES

@@ -36,6 +36,7 @@ from backend.copilot.response_model import (
StreamReasoningStart,
StreamStart,
StreamStartStep,
StreamStatus,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -374,8 +375,41 @@ class SDKResponseAdapter:
responses.append(StreamFinishStep())
self.step_open = False

# Narrate the gap between "tool returned" and "model emits its
# next chunk". Usually sub-second, but with large tool outputs
# or complex continuations it can stretch long enough for the
# generic "Thinking…" copy to feel dead. The frontend replaces
# it with actual content as soon as the next chunk lands.
if resolved_in_blocks:
responses.append(StreamStatus(message="Analyzing result\u2026"))

elif isinstance(sdk_message, ResultMessage):
self.flush_unresolved_tool_calls(responses)
# SECRT-2252: surface ghost-finished sessions as errors instead of silent finishes.
if sdk_message.subtype == "success" and self._is_empty_completion(
sdk_message
):
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
responses.append(
StreamError(
errorText="The model returned an empty response.",
code="empty_completion",
)
)
# Pair with StreamFinish so ``acc.stream_completed`` flips True
# in ``_dispatch_response`` — without it the service-layer
# post-stream branch mis-classifies the turn as "stopped by
# user" and appends a STOPPED_BY_USER_MARKER on top of the
# error marker.
responses.append(StreamFinish())
logger.warning(
"[SDK] [%s] Empty-success ResultMessage detected — "
"emitting stream error instead of silent finish",
(self.session_id or "?")[:12],
)
return responses
# Thinking-only final turn guard: when the model's last LLM
# call after a tool result produced only a ``ThinkingBlock``
# (no ``TextBlock``, no ``ToolUseBlock``) the UI has nothing
@@ -437,6 +471,25 @@ class SDKResponseAdapter:

return responses

def _is_empty_completion(self, msg: ResultMessage) -> bool:
"""True when a success ResultMessage carries no content at all.

Detects the SDK's ghost-finished session: empty ``result``, zero
``output_tokens``, and nothing emitted on the wire this turn (no
text, no reasoning, no tool calls).
"""
if msg.result:
return False
if self.has_started_text or self.has_started_reasoning:
return False
if self.current_tool_calls:
return False
if self._any_tool_results_seen:
return False
usage = msg.usage or {}
output_tokens = usage.get("output_tokens") or 0
return output_tokens == 0

def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:

@@ -25,6 +25,7 @@ from backend.copilot.response_model import (
StreamReasoningEnd,
StreamStart,
StreamStartStep,
StreamStatus,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -193,13 +194,15 @@ def test_tool_result_emits_output_and_finish_step():
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 2
assert len(results) == 3
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent" # prefix stripped
assert results[0].output == "found 3 agents"
assert results[0].success is True
assert isinstance(results[1], StreamFinishStep)
assert isinstance(results[2], StreamStatus)
assert results[2].message == "Analyzing result…"


def test_tool_result_error():
@@ -565,6 +568,105 @@ def test_result_success_does_not_synthesize_when_no_tools_ran():
assert text_deltas == []


def test_result_empty_success_emits_error_and_finish():
"""SECRT-2252: a ``subtype="success"`` ResultMessage with empty ``result``,
no produced content, and ``output_tokens == 0`` is the SDK's ghost-finish
bug. The adapter surfaces it as a ``StreamError`` *paired with*
``StreamFinish`` so the service-layer post-stream flow flips
``acc.stream_completed`` and skips the ``STOPPED_BY_USER_MARKER``
branch. ``SystemMessage(subtype="init")`` opened a step, so the
empty-completion branch must close it before emitting the error."""
adapter = _adapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
result=None,
usage={"input_tokens": 5, "output_tokens": 0},
)
results = adapter.convert_message(msg)
types = [type(r).__name__ for r in results]
assert "StreamFinishStep" in types
assert "StreamError" in types
assert "StreamFinish" in types
# Open step must be closed before the error, and the error must
# precede StreamFinish on the wire.
assert types.index("StreamFinishStep") < types.index("StreamError")
assert types.index("StreamError") < types.index("StreamFinish")
err = next(r for r in results if isinstance(r, StreamError))
assert err.code == "empty_completion"


def test_result_empty_success_with_empty_string_result_treated_as_empty():
"""An empty string (not just None) for ``result`` is also empty."""
adapter = _adapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
result="",
usage={"output_tokens": 0},
)
results = adapter.convert_message(msg)
err = next(r for r in results if isinstance(r, StreamError))
assert err.code == "empty_completion"
assert any(isinstance(r, StreamFinish) for r in results)


def test_result_success_with_text_emits_finish_not_error():
"""Non-empty success (text was produced) keeps the existing
``StreamFinish`` behaviour — no spurious error."""
adapter = _adapter()
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="hello")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
result="hello",
usage={"output_tokens": 5},
)
results = adapter.convert_message(msg)
types = [type(r).__name__ for r in results]
assert "StreamFinish" in types
assert "StreamError" not in types


def test_result_success_with_nonzero_output_tokens_not_empty():
"""If ``output_tokens > 0`` but ``result`` is empty, don't classify as
empty — fall through to the existing success path. No prior
AssistantMessage so the `output_tokens` guard is the only thing
keeping `_is_empty_completion()` from firing."""
adapter = _adapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
result="",
usage={"output_tokens": 50},
)
results = adapter.convert_message(msg)
types = [type(r).__name__ for r in results]
assert "StreamFinish" in types
assert "StreamError" not in types


def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
@@ -686,6 +788,7 @@ def test_full_conversation_flow():
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamFinishStep", # step 1 closed after tool result
"StreamStatus", # user-facing status while continuation is generated
"StreamStartStep", # step 2: continuation text
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"

@@ -51,7 +51,6 @@ from ..constants import (
COPILOT_RETRYABLE_ERROR_PREFIX,
FRIENDLY_TRANSIENT_MSG,
STOPPED_BY_USER_MARKER,
STREAM_IDLE_TIMEOUT_SECONDS,
is_transient_api_error,
)
from ..session_cleanup import prune_orphan_tool_calls
@@ -185,13 +184,32 @@ _CIRCUIT_BREAKER_ERROR_MSG = (
"Try breaking your request into smaller parts."
)

# Idle timeout: abort the stream if no meaningful SDK message (only heartbeats)
# arrives for this many seconds. Derived from MAX_TOOL_WAIT_SECONDS so the
# invariant "no single tool blocks close to this long" holds by construction —
# long-running tools use the async "start + poll" pattern (initial tool returns
# with a handle, polling tool waits in ≤MAX_TOOL_WAIT_SECONDS chunks), so an
# idle of 2× that genuinely means the SDK itself is stuck.
_IDLE_TIMEOUT_SECONDS = STREAM_IDLE_TIMEOUT_SECONDS
# Two regimes: no tool pending → 30 min (SDK genuinely idle); tool pending →
# 2 h hard cap (lets long sub-AutoPilots run, still backstops a hung tool).
_IDLE_TIMEOUT_SECONDS = 30 * 60
_HUNG_TOOL_CAP_SECONDS = 2 * 60 * 60


def _idle_timeout_threshold(adapter: SDKResponseAdapter) -> int:
"""Pick the idle-timeout threshold for the current heartbeat.

Returns ``_HUNG_TOOL_CAP_SECONDS`` (longer) whenever any tool call is
still pending, so a legitimately long operation isn't killed. Returns
``_IDLE_TIMEOUT_SECONDS`` (shorter) when nothing is pending — the SDK
itself is idle with no work in flight.
"""
if adapter.has_unresolved_tool_calls:
return _HUNG_TOOL_CAP_SECONDS
return _IDLE_TIMEOUT_SECONDS


# StreamError codes that should render as a retryable error in the UI (retry
# button) rather than a terminal ErrorCard. Codes appended via
# ``_append_error_marker`` directly already pass ``retryable=True``; this set
# covers the codes that flow through the adapter -> ``_dispatch_response``.
_RETRYABLE_STREAM_ERROR_CODES: frozenset[str] = frozenset(
{"transient_api_error", "empty_completion"}
)


# Event types that are ephemeral / cosmetic and must NOT be counted toward
@@ -535,6 +553,26 @@ async def _reduce_context(
return ReducedContext(TranscriptBuilder(), False, None, True, True, retry_target)


def _humanise_tool_list(names: list[str]) -> str:
"""Format a list of tool names for user-facing messages.

``["WebSearch"]`` → ``"'WebSearch'"``
``["WebSearch", "run_block"]`` → ``"'WebSearch' and 'run_block'"``
Three or more items collapse to ``"'A', 'B', and 1 other"`` so the
toast stays readable.
"""
if not names:
return ""
quoted = [f"'{n}'" for n in names]
if len(quoted) == 1:
return quoted[0]
if len(quoted) == 2:
return f"{quoted[0]} and {quoted[1]}"
extras = len(quoted) - 2
suffix = "others" if extras > 1 else "other"
return f"{quoted[0]}, {quoted[1]}, and {extras} {suffix}"


def _append_error_marker(
session: ChatSession | None,
display_msg: str,
@@ -901,38 +939,46 @@ async def _iter_sdk_messages(


def _normalize_model_name(raw_model: str) -> str:
"""Normalize a model name for the current routing configuration.
"""Normalize a model name for the **actual** SDK CLI transport.

Two routing modes:
Three transports (see ``ChatConfig.effective_transport``):

1. **OpenRouter active** — the canonical OpenRouter slug is
``"<vendor>/<model>"`` (e.g. ``"anthropic/claude-opus-4.6"``,
``"moonshotai/kimi-k2.6"``). Pass the prefixed name through
1. **OpenRouter** — the canonical OpenRouter slug is
``"<vendor>/<model>"`` (e.g. ``"anthropic/claude-opus-4-6"``,
``"moonshotai/kimi-k2-6"``). Pass the prefixed name through
unchanged so OpenRouter can route to the correct provider. Anthropic
names happen to also resolve when stripped, but non-Anthropic vendors
(Moonshot, Google, etc.) do not — keeping the prefix is the only form
that works for every model in the catalog.
2. **Direct Anthropic** — strip the OpenRouter ``anthropic/`` prefix
and convert dots to hyphens (``"claude-opus-4.6"`` →
``"claude-opus-4-6"``) since the Anthropic Messages API rejects
both the prefix and dot-separated versions. Raises ``ValueError``
when a non-Anthropic vendor slug is paired with direct-Anthropic
mode — silently stripping ``moonshotai/`` would send ``kimi-k2.6``
to the Anthropic API and produce an opaque ``model_not_found``
error far from the misconfiguration source.
2. **Subscription / Direct Anthropic** — strip the OpenRouter
``anthropic/`` prefix and convert dots to hyphens
(``"claude-opus-4.6"`` → ``"claude-opus-4-6"``). The CLI subprocess
(subscription mode) and the Anthropic Messages API both reject the
prefix and dot-separated versions. Raises ``ValueError`` when a
non-Anthropic vendor slug is paired with these transports — silently
stripping ``moonshotai/`` would send ``kimi-k2-6`` to the Anthropic
API / CLI and produce an opaque ``model_not_found`` error far from
the misconfiguration source.

Gating on the **actual transport** (not just config shape) matters
because subscription mode and OpenRouter config can coexist —
``CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true`` paired with a populated
``CHAT_BASE_URL`` / ``CHAT_API_KEY`` (left over from an earlier
OpenRouter setup) used to incorrectly pass ``anthropic/claude-opus-4-7``
to the CLI subprocess, which the CLI rejects.
"""
if config.openrouter_active:
if config.effective_transport == "openrouter":
return raw_model
model = raw_model
if "/" in model:
vendor, model = model.split("/", 1)
if vendor != "anthropic":
raise ValueError(
f"Direct-Anthropic mode (use_openrouter=False or missing "
f"OpenRouter credentials) requires an Anthropic model, got "
f"vendor={vendor!r} from model={raw_model!r}. Set "
f"CHAT_THINKING_STANDARD_MODEL/CHAT_THINKING_ADVANCED_MODEL "
f"to an anthropic/* slug, or enable OpenRouter."
f"{config.effective_transport!r} transport requires an "
f"Anthropic model, got vendor={vendor!r} from "
f"model={raw_model!r}. Set CHAT_THINKING_STANDARD_MODEL/"
f"CHAT_THINKING_ADVANCED_MODEL to an anthropic/* slug, or "
f"enable OpenRouter."
)
return model.replace(".", "-")
|
||||
|
||||
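The two normalisation regimes described in that docstring reduce to a small pure function. This is an illustrative sketch only: `normalize_model_name` and its `transport` parameter are hypothetical stand-ins for the real `_normalize_model_name` / `config.effective_transport` pair in the diff.

```python
def normalize_model_name(raw_model: str, transport: str) -> str:
    """Sketch of the rules above: OpenRouter keeps the full vendor slug;
    subscription / direct-Anthropic strip the ``anthropic/`` prefix,
    reject other vendors, and convert dots to hyphens."""
    if transport == "openrouter":
        # OpenRouter routes on the vendor-qualified slug, dots and all.
        return raw_model
    model = raw_model
    if "/" in model:
        vendor, model = model.split("/", 1)
        if vendor != "anthropic":
            # Fail loudly here rather than let the CLI / Messages API
            # return an opaque model_not_found far from the misconfig.
            raise ValueError(
                f"{transport!r} transport requires an Anthropic model, "
                f"got vendor={vendor!r} from model={raw_model!r}"
            )
    return model.replace(".", "-")
```

So `anthropic/claude-opus-4.6` becomes `claude-opus-4-6` for the CLI, stays untouched for OpenRouter, and `moonshotai/kimi-k2.6` raises instead of being silently mangled.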
@@ -1258,6 +1304,58 @@ def _write_cli_session_to_disk(
     return False
 
 
+def delete_stale_cli_session_file(
+    sdk_cwd: str,
+    session_id: str,
+    log_prefix: str,
+) -> bool:
+    """Delete the local CLI session file at the predictable path.
+
+    Used so a subsequent CLI invocation with ``--session-id`` (no ``--resume``)
+    doesn't trip ``"Session ID already in use"``. Path-traversal guard:
+    rejects paths outside the CLI projects base.
+
+    Returns True if a file was deleted, False otherwise (missing, traversal,
+    or unlink failure).
+    """
+    real_path = os.path.realpath(cli_session_path(sdk_cwd, session_id))
+    if not real_path.startswith(projects_base() + os.sep):
+        # Mirror ``_write_cli_session_to_disk``'s defense-in-depth: log
+        # rather than fail silently when the resolved path escapes the
+        # projects base. In normal operation this is unreachable
+        # (session_id is a server-generated UUID and ``cli_session_path``
+        # is deterministic), so a hit indicates a config or tampering
+        # issue that's worth surfacing.
+        logger.warning(
+            "%s CLI session delete path outside projects base: %s",
+            log_prefix,
+            os.path.basename(real_path),
+        )
+        return False
+    # Direct unlink — no exists() check (avoids TOCTOU with the file being
+    # deleted by another process between check and unlink).
+    try:
+        Path(real_path).unlink()
+        logger.info(
+            "%s Removed stale local CLI session file at %s",
+            log_prefix,
+            os.path.basename(real_path),
+        )
+        return True
+    except FileNotFoundError:
+        return False
+    except OSError as unlink_err:
+        # Sanitise log: basename + strerror only (no full path / no raw
+        # exception which can echo absolute paths back in some libc errors).
+        logger.warning(
+            "%s Failed to remove stale local CLI session file %s: %s",
+            log_prefix,
+            os.path.basename(real_path),
+            unlink_err.strerror or type(unlink_err).__name__,
+        )
+        return False
+
+
 def read_cli_session_from_disk(
     sdk_cwd: str,
     session_id: str,
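The two defenses in that function (realpath-then-containment check, unlink without an `exists()` pre-check) generalise to a small pattern, sketched here with a hypothetical `safe_unlink` helper; logging is omitted and names are not from the codebase.

```python
import os
from pathlib import Path


def safe_unlink(candidate: str, base: str) -> bool:
    """Delete ``candidate`` only if it resolves inside ``base``.

    Resolving symlinks *before* the containment check closes the
    traversal hole; unlinking directly (catching FileNotFoundError)
    avoids the TOCTOU race of ``exists()`` followed by ``unlink()``.
    """
    real = os.path.realpath(candidate)
    if not real.startswith(os.path.realpath(base) + os.sep):
        return False  # resolved path escaped the base directory: refuse
    try:
        Path(real).unlink()  # no exists() pre-check, see above
        return True
    except FileNotFoundError:
        return False  # another process deleted it first: nothing removed
    except OSError:
        return False  # permissions etc.: report "nothing deleted"
```

The `startswith(... + os.sep)` form matters: comparing against `base` alone would accept a sibling like `/projects-evil/x` next to `/projects`.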
@@ -2026,7 +2124,7 @@ def _dispatch_response(
         _append_error_marker(
             ctx.session,
             response.errorText,
-            retryable=(response.code == "transient_api_error"),
+            retryable=response.code in _RETRYABLE_STREAM_ERROR_CODES,
         )
 
     if isinstance(response, StreamReasoningStart):
@@ -2354,6 +2452,13 @@ async def _run_stream_attempt(
     for ev in ctx.compaction.emit_pre_query(ctx.session):
         yield ev
 
+    # Narrate the silent gap between dispatching the query and the
+    # SDK's first real chunk — usually <1s but can stretch to several
+    # seconds on cold-starts or large contexts. The frontend prefers
+    # this over the generic "Thinking…" copy; fast turns replace it
+    # with content immediately.
+    yield StreamStatus(message="Contacting the model\u2026")
+
     if ctx.attachments.image_blocks:
         content_blocks: list[dict[str, Any]] = [
             *ctx.attachments.image_blocks,
@@ -2388,21 +2493,41 @@ async def _run_stream_attempt(
                 yield ev
             yield StreamHeartbeat()
 
             # Idle timeout: abort if the SDK has been silent for too long.
-            # Long-running tools use the async "start + poll" pattern so
-            # the MCP handler never blocks longer than the poll cap (5 min)
-            # — a 10-min gap here means the SDK itself is stuck.
+            # Threshold flips to the long cap while a tool is pending; clock never resets.
            idle_seconds = time.monotonic() - _last_real_msg_time
-            if idle_seconds >= _IDLE_TIMEOUT_SECONDS:
+            threshold = _idle_timeout_threshold(state.adapter)
+            if idle_seconds >= threshold:
+                unresolved_tool_names = sorted(
+                    {
+                        info.get("name", "unknown")
+                        for tid, info in state.adapter.current_tool_calls.items()
+                        if tid not in state.adapter.resolved_tool_calls
+                    }
+                )
                 logger.error(
-                    "%s Idle timeout after %.0fs — aborting stream",
+                    "%s Idle timeout after %.0fs (threshold=%ds, "
+                    "unresolved tools: %s) — aborting stream",
                     ctx.log_prefix,
                     idle_seconds,
+                    threshold,
+                    ", ".join(unresolved_tool_names) or "none",
                 )
+                # The retryable marker written to the session omits
+                # the `[code:<id>]` prefix — the AI SDK serializer
+                # (`StreamError.to_sse`) attaches that automatically
+                # on the wire so the frontend can still parse a
+                # machine-readable code out of the otherwise opaque
+                # `{type, errorText}` schema.
+                stream_error_code = "idle_timeout"
+                tool_phrase = (
+                    f" while running {_humanise_tool_list(unresolved_tool_names)}"
+                    if unresolved_tool_names
+                    else ""
+                )
                 stream_error_msg = (
-                    "The session has been idle for too long. Please try again."
+                    f"AutoPilot stopped responding{tool_phrase}. "
+                    "This usually means a tool got stuck. Please try again."
                 )
-                stream_error_code = "idle_timeout"
                 _append_error_marker(ctx.session, stream_error_msg, retryable=True)
                 yield StreamError(
                     errorText=stream_error_msg,
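The threshold flip in that hunk reduces to a pure function over the adapter's tool-call bookkeeping. A minimal sketch under assumed constants (the 30-minute and 2-hour values come from the test expectations elsewhere in this diff); `idle_timeout_threshold` is a hypothetical stand-in for `_idle_timeout_threshold`:

```python
IDLE_TIMEOUT_SECONDS = 30 * 60       # idle with no tool pending
HUNG_TOOL_CAP_SECONDS = 2 * 60 * 60  # hard cap while a tool is pending


def idle_timeout_threshold(current_tool_calls: dict, resolved_tool_calls: set) -> int:
    """Long cap while any started tool call is still unresolved,
    short cap otherwise (pure idle)."""
    has_pending = any(
        tool_id not in resolved_tool_calls for tool_id in current_tool_calls
    )
    return HUNG_TOOL_CAP_SECONDS if has_pending else IDLE_TIMEOUT_SECONDS
```

The design point: a 45-minute sub-AutoPilot run keeps its pending tool call open and gets the 2-hour budget, while a stream that is idle with nothing in flight is cut off after 30 minutes.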
@@ -3082,22 +3207,7 @@ async def _restore_cli_session_for_turn(
         # session_id with "Session ID already in use". T1 may have
         # left a valid file at this path; we clear it so the fallback
         # path (session_id= without --resume) can create a new session.
-        _stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id))
-        if Path(_stale_path).exists() and _stale_path.startswith(
-            projects_base() + os.sep
-        ):
-            try:
-                Path(_stale_path).unlink()
-                logger.debug(
-                    "%s Removed stale local CLI session file for clean fallback",
-                    log_prefix,
-                )
-            except OSError as _unlink_err:
-                logger.debug(
-                    "%s Failed to remove stale local session file: %s",
-                    log_prefix,
-                    _unlink_err,
-                )
+        delete_stale_cli_session_file(sdk_cwd, session_id, log_prefix)
 
     if cli_restore is not None:
         result.transcript_content = stripped
@@ -3943,21 +4053,21 @@ async def stream_chat_completion_sdk(  # pyright: ignore[reportGeneralTypeIssues
             if ctx.use_resume and ctx.resume_file:
                 sdk_options_kwargs_retry["resume"] = ctx.resume_file
                 sdk_options_kwargs_retry.pop("session_id", None)
-            elif "session_id" in sdk_options_kwargs:
-                # Initial invocation used session_id (T1 or mode-switch
-                # T1): keep it so the CLI writes the session file to the
-                # predictable path for upload_transcript(). Storage is
-                # ephemeral per invocation, so no "Session ID already in
-                # use" conflict occurs — no prior file was restored.
-                sdk_options_kwargs_retry.pop("resume", None)
-                sdk_options_kwargs_retry["session_id"] = session_id
-            else:
-                # T2+ retry without --resume: initial invocation used
-                # --resume, which restored the T1 session file to local
-                # storage. Re-using session_id without --resume would
-                # fail with "Session ID already in use".
-                sdk_options_kwargs_retry.pop("resume", None)
-                sdk_options_kwargs_retry.pop("session_id", None)
+            else:
+                # No --resume on this retry. Whether we entered with
+                # ``session_id`` (T1, mode-switch) or with ``--resume`` (T2+),
+                # we want the recovery turn's CLI write to land on the
+                # predictable ``cli_session_path(.., session_id)`` so the
+                # post-turn ``upload_transcript`` actually picks up the
+                # rescued (compacted) content. Without this, a T2+ retry
+                # would drop session_id to dodge "Session ID already in use",
+                # write to a random path, and the upload would silently grab
+                # the stale pre-failure file — leaving GCS bloated and
+                # guaranteeing the next turn re-trips prompt-too-long.
+                if sdk_cwd:
+                    delete_stale_cli_session_file(sdk_cwd, session_id, log_prefix)
+                sdk_options_kwargs_retry.pop("resume", None)
+                sdk_options_kwargs_retry["session_id"] = session_id
             # Recompute system_prompt for retry — the preset is safe on
             # every turn (requires CLI ≥ 2.1.98, installed in the Docker
             # image and configured via CHAT_CLAUDE_AGENT_CLI_PATH).

@@ -13,6 +13,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock
 
+from backend.copilot import config as cfg_mod
+from backend.copilot.config import ChatConfig
+
 from .conftest import build_test_transcript as _build_transcript
 from .service import (
     _RETRY_TARGET_TOKENS,
@@ -23,6 +26,7 @@ from .service import (
     _iter_sdk_messages,
     _normalize_model_name,
     _reduce_context,
+    _resolve_sdk_model_for_request,
     _restore_cli_session_for_turn,
     _TokenUsage,
 )
@@ -373,15 +377,15 @@ class TestNormalizeModelName:
     """
 
     @pytest.fixture
-    def _direct_anthropic_config(self, monkeypatch: pytest.MonkeyPatch):
+    def _direct_anthropic_config(
+        self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
+    ):
         """Force ``config.openrouter_active = False`` for prefix-strip tests.
 
+        Pins the SDK model fields to anthropic/* so the new
+        ``_validate_sdk_model_vendor_compatibility`` model_validator
+        permits ChatConfig construction.
+        """
         from backend.copilot import config as cfg_mod
 
         cfg = cfg_mod.ChatConfig(
             use_openrouter=False,
             api_key=None,
@@ -393,10 +397,10 @@ class TestNormalizeModelName:
         monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
 
     @pytest.fixture
-    def _openrouter_config(self, monkeypatch: pytest.MonkeyPatch):
+    def _openrouter_config(
+        self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
+    ):
         """Force ``config.openrouter_active = True`` for slug-preservation tests."""
         from backend.copilot import config as cfg_mod
 
         cfg = cfg_mod.ChatConfig(
             use_openrouter=True,
             api_key="or-key",
@@ -445,6 +449,172 @@ class TestNormalizeModelName:
         """Non-Anthropic vendors (Moonshot) require the prefix to route."""
         assert _normalize_model_name("moonshotai/kimi-k2.6") == "moonshotai/kimi-k2.6"
 
+    @pytest.fixture
+    def _subscription_with_openrouter_config(
+        self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
+    ):
+        """Subscription mode with leftover OpenRouter base_url + api_key.
+
+        Reproduces the bug: ``CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true`` plus
+        a populated ``CHAT_BASE_URL`` (e.g. left over from an earlier
+        OpenRouter setup) used to incorrectly preserve the OpenRouter slug
+        because the gate checked config shape (``openrouter_active``) not
+        actual transport. The CLI subprocess uses OAuth here and rejects
+        the OpenRouter format.
+        """
+        cfg = cfg_mod.ChatConfig(
+            use_openrouter=True,
+            api_key="or-key",
+            base_url="https://openrouter.ai/api/v1",
+            use_claude_code_subscription=True,
+        )
+        monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
+
+    def test_subscription_strips_anthropic_prefix_despite_openrouter_config(
+        self, _subscription_with_openrouter_config
+    ):
+        """Subscription transport must produce the CLI-friendly form even
+        when OpenRouter base_url + api_key are set — the CLI uses OAuth
+        and ignores those fields, so the OpenRouter slug would be rejected."""
+        assert _normalize_model_name("anthropic/claude-opus-4.7") == "claude-opus-4-7"
+
+    def test_subscription_rejects_non_anthropic_vendor(
+        self, _subscription_with_openrouter_config
+    ):
+        """The CLI subprocess can only talk to Anthropic models — Kimi via
+        Moonshot must raise so the resolver falls back to a tier default
+        instead of feeding an unroutable slug to the CLI."""
+        with pytest.raises(ValueError, match="requires an Anthropic model"):
+            _normalize_model_name("moonshotai/kimi-k2.6")
+
+
+# ---------------------------------------------------------------------------
+# ChatConfig.effective_transport — single source of truth for "which
+# transport will the SDK CLI actually use?"
+# ---------------------------------------------------------------------------
+
+
+class TestEffectiveTransport:
+    """Subscription mode wins over OpenRouter even when OpenRouter
+    base_url + api_key are set, because the CLI subprocess uses OAuth and
+    ignores ``CHAT_BASE_URL`` / ``CHAT_API_KEY`` (see ``build_sdk_env``
+    mode 1). Picking the right transport here is what lets
+    ``_normalize_model_name`` produce the correct model-name format.
+    """
+
+    def test_subscription_wins_over_openrouter_config(self, _clean_config_env):
+        cfg = ChatConfig(
+            use_openrouter=True,
+            api_key="or-key",
+            base_url="https://openrouter.ai/api/v1",
+            use_claude_code_subscription=True,
+        )
+        assert cfg.effective_transport == "subscription"
+        # ``openrouter_active`` is still True (config-shape check) but
+        # the actual transport is subscription.
+        assert cfg.openrouter_active is True
+
+    def test_openrouter_when_subscription_disabled(self, _clean_config_env):
+        cfg = ChatConfig(
+            use_openrouter=True,
+            api_key="or-key",
+            base_url="https://openrouter.ai/api/v1",
+            use_claude_code_subscription=False,
+        )
+        assert cfg.effective_transport == "openrouter"
+
+    def test_direct_anthropic_when_no_openrouter_no_subscription(
+        self, _clean_config_env
+    ):
+        cfg = ChatConfig(
+            use_openrouter=False,
+            api_key=None,
+            base_url=None,
+            use_claude_code_subscription=False,
+            thinking_standard_model="anthropic/claude-sonnet-4-6",
+            thinking_advanced_model="anthropic/claude-opus-4-7",
+        )
+        assert cfg.effective_transport == "direct_anthropic"
+
+    def test_subscription_alone_is_subscription(self, _clean_config_env):
+        cfg = ChatConfig(
+            use_openrouter=False,
+            api_key=None,
+            base_url=None,
+            use_claude_code_subscription=True,
+        )
+        assert cfg.effective_transport == "subscription"
+
+
+# ---------------------------------------------------------------------------
+# _resolve_sdk_model_for_request — transport-aware LD-override normalisation
+# ---------------------------------------------------------------------------
+
+
+class TestResolveSdkModelForRequestTransportAware:
+    """When subscription mode is on but the deployment also has OpenRouter
+    config populated (e.g. ``CHAT_BASE_URL`` left over from a previous
+    setup), an LD-served override must be normalised for the **subscription
+    CLI**, not passed through as the OpenRouter slug. The CLI subprocess
+    uses OAuth and rejects ``anthropic/claude-opus-4.7`` with the model
+    error reproduced in local debugging:
+
+    ``There's an issue with the selected model
+    (anthropic/claude-opus-4.7). It may not exist or you may not have
+    access to it.``
+    """
+
+    @pytest.mark.asyncio
+    async def test_subscription_advanced_override_normalised_for_cli(
+        self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
+    ):
+        cfg = cfg_mod.ChatConfig(
+            thinking_standard_model="anthropic/claude-sonnet-4-6",
+            thinking_advanced_model="anthropic/claude-opus-4.7",
+            claude_agent_model=None,
+            use_openrouter=True,
+            api_key="or-key",
+            base_url="https://openrouter.ai/api/v1",
+            use_claude_code_subscription=True,
+        )
+        monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
+
+        with patch(
+            "backend.copilot.sdk.service._resolve_thinking_model_for_user",
+            new=AsyncMock(return_value="anthropic/claude-opus-4.7"),
+        ):
+            resolved = await _resolve_sdk_model_for_request(
+                model="advanced", session_id="sess-adv", user_id="user-1"
+            )
+        # NOT the OpenRouter slug, NOT None — the CLI-friendly hyphenated form.
+        assert resolved == "claude-opus-4-7"
+
+    @pytest.mark.asyncio
+    async def test_subscription_standard_no_override_returns_none(
+        self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
+    ):
+        """When LD agrees with the config default, subscription mode still
+        wins on the standard tier — returns ``None`` so the CLI picks the
+        subscription default model."""
+        cfg = cfg_mod.ChatConfig(
+            thinking_standard_model="anthropic/claude-sonnet-4-6",
+            claude_agent_model=None,
+            use_openrouter=True,
+            api_key="or-key",
+            base_url="https://openrouter.ai/api/v1",
+            use_claude_code_subscription=True,
+        )
+        monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
+
+        with patch(
+            "backend.copilot.sdk.service._resolve_thinking_model_for_user",
+            new=AsyncMock(return_value="anthropic/claude-sonnet-4-6"),
+        ):
+            resolved = await _resolve_sdk_model_for_request(
+                model="standard", session_id="sess-std", user_id="user-1"
+            )
+        assert resolved is None
+
+
 # ---------------------------------------------------------------------------
 # _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)
@@ -566,17 +736,20 @@ def _build_retry_sdk_options(
     ctx_resume_file: str | None,
     session_id: str,
 ) -> dict:
-    """Mirror the retry branch in stream_chat_completion_sdk."""
+    """Mirror the retry branch in stream_chat_completion_sdk.
+
+    Production-side companion: ``delete_stale_cli_session_file`` is invoked
+    on every non-resume retry path so the CLI doesn't trip "Session ID
+    already in use" when we re-attach ``session_id``. This helper only
+    mirrors the kwarg shape (file-system side effect is tested separately).
+    """
     retry: dict = dict(initial_kwargs)
     if ctx_use_resume and ctx_resume_file:
         retry["resume"] = ctx_resume_file
         retry.pop("session_id", None)
     elif "session_id" in initial_kwargs:
         retry.pop("resume", None)
         retry["session_id"] = session_id
     else:
         retry.pop("resume", None)
-        retry.pop("session_id", None)
+        retry["session_id"] = session_id
     return retry

@@ -648,12 +821,21 @@ class TestSdkSessionIdSelection:
         assert retry.get("session_id") == self.SESSION_ID
         assert "resume" not in retry
 
-    def test_retry_removes_session_id_for_t2_plus(self):
-        """Retry for T2+ (initial used --resume) removes session_id to avoid conflict."""
+    def test_retry_keeps_session_id_for_t2_plus(self):
+        """Retry for T2+ now keeps session_id so the recovery turn writes to
+        the predictable ``cli_session_path`` and gets uploaded. Production
+        clears the stale local file via ``delete_stale_cli_session_file``
+        before this branch runs to dodge "Session ID already in use".
+
+        Regression guard for SENTRY-1207: previously this branch dropped
+        session_id, the CLI wrote to a random path, and the post-turn
+        upload silently grabbed the stale pre-failure file — so GCS stayed
+        bloated and every subsequent turn re-tripped prompt-too-long.
+        """
         initial = _build_sdk_options(True, self.SESSION_ID, self.SESSION_ID)
         # T2+ retry where context reduction dropped --resume
         retry = _build_retry_sdk_options(initial, False, None, self.SESSION_ID)
-        assert "session_id" not in retry
+        assert retry.get("session_id") == self.SESSION_ID
         assert "resume" not in retry
 
     def test_retry_t2_with_resume_sets_resume(self):
@@ -1127,3 +1309,78 @@ class TestCompactionTargetTokens:
 
     # Target derived from the RUNTIME model, not the compactor model.
     assert captured["target_tokens"] == 12345
+
+
+# ---------------------------------------------------------------------------
+# delete_stale_cli_session_file — clears a leftover local session file so a
+# subsequent --session-id (no --resume) invocation doesn't trip "Session ID
+# already in use". Critical for the prompt-too-long retry path.
+# ---------------------------------------------------------------------------
+
+
+class TestDeleteStaleCliSessionFile:
+    def test_deletes_file_when_present(self, tmp_path) -> None:
+        from backend.copilot.sdk.service import delete_stale_cli_session_file
+
+        sdk_cwd = str(tmp_path / "cwd")
+        session_id = "sess-deadbeef"
+
+        with (
+            patch(
+                "backend.copilot.sdk.service.cli_session_path",
+                return_value=str(tmp_path / "session.jsonl"),
+            ),
+            patch(
+                "backend.copilot.sdk.service.projects_base",
+                return_value=str(tmp_path),
+            ),
+        ):
+            target = tmp_path / "session.jsonl"
+            target.write_text("{}\n")
+
+            removed = delete_stale_cli_session_file(sdk_cwd, session_id, "[t]")
+
+        assert removed is True
+        assert not target.exists()
+
+    def test_returns_false_when_file_missing(self, tmp_path) -> None:
+        from backend.copilot.sdk.service import delete_stale_cli_session_file
+
+        with (
+            patch(
+                "backend.copilot.sdk.service.cli_session_path",
+                return_value=str(tmp_path / "missing.jsonl"),
+            ),
+            patch(
+                "backend.copilot.sdk.service.projects_base",
+                return_value=str(tmp_path),
+            ),
+        ):
+            removed = delete_stale_cli_session_file("/cwd", "sess", "[t]")
+
+        assert removed is False
+
+    def test_path_traversal_guard_rejects_outside_projects_base(self, tmp_path) -> None:
+        """Refuse to delete files outside the projects base, even if they exist."""
+        from backend.copilot.sdk.service import delete_stale_cli_session_file
+
+        outside = tmp_path / "outside.jsonl"
+        outside.write_text("data")
+        projects = tmp_path / "projects"
+        projects.mkdir()
+
+        with (
+            patch(
+                "backend.copilot.sdk.service.cli_session_path",
+                return_value=str(outside),
+            ),
+            patch(
+                "backend.copilot.sdk.service.projects_base",
+                return_value=str(projects),
+            ),
+        ):
+            removed = delete_stale_cli_session_file("/cwd", "sess", "[t]")
+
+        # File was outside projects base — guard rejected, file untouched.
+        assert removed is False
+        assert outside.exists()

@@ -12,8 +12,11 @@ import pytest
 from backend.copilot import config as cfg_mod
 
 from .service import (
+    _HUNG_TOOL_CAP_SECONDS,
     _IDLE_TIMEOUT_SECONDS,
     _build_system_prompt_value,
+    _humanise_tool_list,
+    _idle_timeout_threshold,
     _is_sdk_disconnect_error,
     _normalize_model_name,
     _prepare_file_attachments,
@@ -323,27 +326,11 @@ class TestCleanupSdkToolResults:
 
 
 # ---------------------------------------------------------------------------
-# Env vars that ChatConfig validators read — must be cleared so explicit
-# constructor values are used.
+# Env-cleanup fixture is shared via ``conftest._clean_config_env``. This
+# file exposes a re-export for callers that don't rely on conftest discovery
+# (kept for backwards compatibility — pytest finds the conftest fixture
+# automatically without an explicit import).
 # ---------------------------------------------------------------------------
-_CONFIG_ENV_VARS = (
-    "CHAT_USE_OPENROUTER",
-    "CHAT_API_KEY",
-    "OPEN_ROUTER_API_KEY",
-    "OPENAI_API_KEY",
-    "CHAT_BASE_URL",
-    "OPENROUTER_BASE_URL",
-    "OPENAI_BASE_URL",
-    "CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
-    "CHAT_USE_CLAUDE_AGENT_SDK",
-    "CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE",
-)
-
-
-@pytest.fixture()
-def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
-    for var in _CONFIG_ENV_VARS:
-        monkeypatch.delenv(var, raising=False)
 
 
 class TestNormalizeModelName:
@@ -617,7 +604,13 @@ class TestResolveSdkModelForRequestLdFallback:
         on ``copilot-model-routing[thinking][standard]`` returned
         ``None`` (CLI picked subscription default Opus), silently
         ignoring the LD override. An LD value different from the
-        config default is an explicit admin decision and must win."""
+        config default is an explicit admin decision and must win.
+
+        Subscription transport rejects non-Anthropic vendors (the CLI
+        subprocess can't talk to Moonshot), so the resolver fails soft
+        to the tier default normalised for the subscription transport
+        (``claude-sonnet-4-6``) — not ``None``, which would silently
+        re-introduce the old subscription-default bypass."""
         cfg = cfg_mod.ChatConfig(
             thinking_standard_model="anthropic/claude-sonnet-4-6",
             claude_agent_model=None,
@@ -635,8 +628,9 @@ class TestResolveSdkModelForRequestLdFallback:
         resolved = await _resolve_sdk_model_for_request(
             model="standard", session_id="sess-std-sub", user_id="user-1"
         )
-        # Expect LD-served Kimi, NOT None (the old subscription-default bypass)
-        assert resolved == "moonshotai/kimi-k2.6"
+        # Kimi can't be served by the subscription CLI; fail-soft to
+        # the tier default normalised for the active transport.
+        assert resolved == "claude-sonnet-4-6"
 
     @pytest.mark.asyncio
     async def test_standard_subscription_survives_trailing_whitespace_in_env(
@@ -703,7 +697,10 @@ class TestResolveSdkModelForRequestLdFallback:
         """Subscription mode bypasses LD only on the standard tier —
         the advanced tier always consults LD because the user explicitly
         asked for the premium path. A subscription + advanced request
-        with LD-served Opus must return Opus (not ``None``)."""
+        with LD-served Opus must return Opus normalised for the
+        subscription CLI (``claude-opus-4-7``), not the OpenRouter slug
+        ``anthropic/claude-opus-4.7`` which the CLI subprocess rejects
+        even when ``CHAT_BASE_URL`` is set to the OpenRouter proxy."""
         cfg = cfg_mod.ChatConfig(
             thinking_standard_model="anthropic/claude-sonnet-4-6",
             thinking_advanced_model="anthropic/claude-opus-4.7",
@@ -722,7 +719,7 @@ class TestResolveSdkModelForRequestLdFallback:
         resolved = await _resolve_sdk_model_for_request(
             model="advanced", session_id="sess-adv-sub", user_id="user-1"
         )
-        assert resolved == "anthropic/claude-opus-4.7"
+        assert resolved == "claude-opus-4-7"
 
 
 # ---------------------------------------------------------------------------
@@ -907,14 +904,51 @@ class TestSystemPromptPreset:
         assert cfg.claude_agent_cross_user_prompt_cache is False
 
 
-class TestIdleTimeoutConstant:
-    """SECRT-2247: long-running work now uses async start+poll pattern
-    (run_sub_session / run_agent), so no single MCP tool call ever blocks
-    the stream close to the idle limit. The plain 10-min cap from the
-    original code is restored."""
+class TestStreamErrorCodePrefix:
+    """StreamError.to_sse auto-prefixes errorText with `[code:<id>]` when a
+    code is set, so the frontend can parse a machine-readable code out of
+    the AI-SDK's strict `{type, errorText}` schema."""
 
-    def test_idle_timeout_is_10_min(self):
-        assert _IDLE_TIMEOUT_SECONDS == 10 * 60
+    def test_auto_prefix_when_code_set(self):
+        from backend.copilot.response_model import StreamError
+
+        sse = StreamError(errorText="Boom", code="idle_timeout").to_sse()
+        assert '"errorText":"[code:idle_timeout] Boom"' in sse
+
+    def test_no_prefix_when_code_missing(self):
+        from backend.copilot.response_model import StreamError
+
+        sse = StreamError(errorText="Boom").to_sse()
+        assert '"errorText":"Boom"' in sse
+
+    def test_does_not_double_prefix(self):
+        from backend.copilot.response_model import StreamError
+
+        sse = StreamError(errorText="[code:x] Boom", code="x").to_sse()
+        assert "[code:x] [code:x]" not in sse
+        assert '"errorText":"[code:x] Boom"' in sse
+
+
+class TestHumaniseToolList:
+    """Tool-name formatter used to build the idle-timeout error message."""
+
+    def test_empty_returns_empty_string(self):
+        assert _humanise_tool_list([]) == ""
+
+    def test_single_tool_is_quoted(self):
+        assert _humanise_tool_list(["WebSearch"]) == "'WebSearch'"
+
+    def test_two_tools_are_joined_with_and(self):
+        assert (
+            _humanise_tool_list(["WebSearch", "run_block"])
+            == "'WebSearch' and 'run_block'"
+        )
+
+    def test_three_uses_singular_other(self):
+        assert _humanise_tool_list(["a", "b", "c"]) == "'a', 'b', and 1 other"
+
+    def test_four_plus_uses_plural_others(self):
+        assert _humanise_tool_list(["a", "b", "c", "d"]) == "'a', 'b', and 2 others"
 
 
 # ---------------------------------------------------------------------------
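The expectations in `TestHumaniseToolList` pin the formatter down completely, so it can be sketched directly from them. `humanise_tool_list` is a hypothetical reimplementation; the real `_humanise_tool_list` lives in the service module and may differ in detail.

```python
def humanise_tool_list(names: list[str]) -> str:
    """Quote up to two tool names; collapse the rest into 'and N other(s)'."""
    if not names:
        return ""
    if len(names) == 1:
        return f"'{names[0]}'"
    if len(names) == 2:
        return f"'{names[0]}' and '{names[1]}'"
    rest = len(names) - 2
    noun = "other" if rest == 1 else "others"  # singular/plural tail
    return f"'{names[0]}', '{names[1]}', and {rest} {noun}"
```

Capping at two named tools keeps the idle-timeout error message short no matter how many calls were in flight.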
@@ -1137,3 +1171,61 @@ class TestMoonshotHelperReexports:
     from .service import _override_cost_for_moonshot
 
     assert _override_cost_for_moonshot is canonical
+
+
+class TestIdleTimeoutThreshold:
+    """SECRT-2247: stream uses two idle thresholds. The shorter 30-min threshold
+    fires when the SDK is idle with no tool pending. The longer 2-hour cap
+    applies while any tool call is pending so a 45-min sub-AutoPilot isn't
+    killed, but a truly hung tool still eventually frees session resources."""
+
+    def _make_adapter(self, current: dict, resolved: set):
+        from backend.copilot.sdk.response_adapter import SDKResponseAdapter
+
+        adapter = SDKResponseAdapter(session_id="test")
+        adapter.current_tool_calls = current
+        adapter.resolved_tool_calls = resolved
+        return adapter
+
+    def test_threshold_uses_long_cap_with_unresolved_tool_call(self):
+        adapter = self._make_adapter(
+            current={"t1": {"name": "run_block"}},
+            resolved=set(),
+        )
+        assert _idle_timeout_threshold(adapter) == _HUNG_TOOL_CAP_SECONDS
+
+    def test_threshold_uses_short_cap_when_all_tools_resolved(self):
+        adapter = self._make_adapter(
+            current={"t1": {"name": "find_agent"}},
+            resolved={"t1"},
+        )
+        assert _idle_timeout_threshold(adapter) == _IDLE_TIMEOUT_SECONDS
+
+    def test_threshold_uses_short_cap_with_no_tool_calls(self):
+        adapter = self._make_adapter(current={}, resolved=set())
+        assert _idle_timeout_threshold(adapter) == _IDLE_TIMEOUT_SECONDS
+
+    def test_threshold_uses_long_cap_with_mixed_resolved_and_pending(self):
+        adapter = self._make_adapter(
+            current={
+                "t1": {"name": "find_agent"},
+                "t2": {"name": "run_block"},
+            },
+            resolved={"t1"},
+        )
+        assert _idle_timeout_threshold(adapter) == _HUNG_TOOL_CAP_SECONDS
+
+    def test_idle_timeout_is_30_min_not_the_old_10(self):
+        # Regression guard: the old 10-min value killed long tool calls
+        # (SECRT-2247). New idle-without-tools cap is 30 min.
+        assert _IDLE_TIMEOUT_SECONDS == 30 * 60
+
+    def test_hung_tool_cap_is_2_hours(self):
+        # Hard cap protects against a hung tool leaking resources forever.
+        # 2 hours is plenty for any legitimate sub-AutoPilot or graph run.
+        assert _HUNG_TOOL_CAP_SECONDS == 2 * 60 * 60
+
+    def test_long_cap_is_strictly_longer_than_short_cap(self):
+        # The whole point of the two-regime design: pending tools get more
+        # patience than pure idle.
+        assert _HUNG_TOOL_CAP_SECONDS > _IDLE_TIMEOUT_SECONDS

@@ -48,6 +48,7 @@ from .response_model import (
    StreamReasoningStart,
    StreamStart,
    StreamStartStep,
    StreamStatus,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
@@ -89,8 +90,19 @@ class ActiveSession:


def _get_session_meta_key(session_id: str) -> str:
    """Get Redis key for session metadata (keyed by session_id).

    Hash-tag braces colocate this key with ``pending_messages._buffer_key``
    on the same Redis Cluster slot — the gated-rpush Lua script touches both
    keys atomically and would CROSSSLOT-fail if they hashed to different
    shards.
    """
    return f"{config.session_meta_prefix}{{{session_id}}}"


def get_session_meta_key(session_id: str) -> str:
    """Get Redis key for session metadata (keyed by session_id)."""
    return f"{config.session_meta_prefix}{session_id}"
    return _get_session_meta_key(session_id)


def _get_turn_stream_key(turn_id: str) -> str:
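The CROSSSLOT note in `_get_session_meta_key`'s docstring follows from the standard Redis Cluster hash-tag rule: only the substring between the first `{` and the following `}` is hashed into a slot. A small illustration (the key prefixes here are made up; only the brace rule matters):

```python
def hash_tag(key: str) -> str:
    """Return the part of a key that Redis Cluster hashes into a slot."""
    start = key.find("{")
    if start != -1:
        end = key.find("}", start + 1)
        if end > start + 1:  # empty tags like "{}" are ignored
            return key[start + 1 : end]
    return key  # no (usable) tag: the whole key is hashed


# Two keys embedding {session_id} the same way hash to the same slot, so a
# multi-key Lua script can touch both without a CROSSSLOT error.
meta_key = "copilot:session_meta:{sess-1}"
buffer_key = "copilot:pending_messages:{sess-1}"
```

Without the braces, the differing prefixes would make the keys hash independently and almost always land on different slots.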
@@ -1093,6 +1105,7 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
        ResponseType.ERROR.value: StreamError,
        ResponseType.USAGE.value: StreamUsage,
        ResponseType.HEARTBEAT.value: StreamHeartbeat,
        ResponseType.STATUS.value: StreamStatus,
    }

    chunk_type = chunk_data.get("type")
@@ -343,3 +343,73 @@ async def test_mark_session_completed_survives_lock_release_redis_error():
        isinstance(call.args[1], stream_registry.StreamFinish)
        for call in publish_mock.call_args_list
    ), "StreamFinish must still be published even if lock DELETE raises"


# ---------------------------------------------------------------------------
# Replays must contain protocol chunks only. Redis cursor data parts are not
# emitted because AI SDK resume needs the complete stream envelope from 0-0.
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_subscribe_to_session_replays_chunks_without_cursor_parts():
    """During replay, the subscriber queue contains chunks plus terminal finish."""
    import orjson

    from backend.copilot.response_model import (
        StreamTextDelta,
        StreamTextEnd,
        StreamTextStart,
    )

    # Three chunks recorded in Redis for a completed turn. Completed status
    # means the listener branch is skipped and only the replay path runs,
    # which keeps the test hermetic.
    stream_key_msgs = [
        (
            "9999-0",
            {"data": orjson.dumps(StreamTextStart(id="blk-1").model_dump()).decode()},
        ),
        (
            "9999-1",
            {
                "data": orjson.dumps(
                    StreamTextDelta(id="blk-1", delta="hi").model_dump()
                ).decode()
            },
        ),
        (
            "9999-2",
            {"data": orjson.dumps(StreamTextEnd(id="blk-1").model_dump()).decode()},
        ),
    ]

    fake_redis = AsyncMock()
    fake_redis.hgetall = AsyncMock(
        return_value={
            "user_id": "u1",
            "session_id": "sess-1",
            "turn_id": "turn-1",
            "status": "completed",  # finished → no listener task
        }
    )
    fake_redis.xread = AsyncMock(return_value=[("stream-key", stream_key_msgs)])

    with patch.object(
        stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis)
    ):
        queue = await stream_registry.subscribe_to_session(
            session_id="sess-1", user_id="u1", last_message_id="0-0"
        )

    assert queue is not None

    delivered = []
    while not queue.empty():
        delivered.append(queue.get_nowait())

    assert len(delivered) == 4
    assert isinstance(delivered[0], StreamTextStart)
    assert isinstance(delivered[1], StreamTextDelta)
    assert isinstance(delivered[2], StreamTextEnd)
    assert isinstance(delivered[3], stream_registry.StreamFinish)

@@ -10,6 +10,7 @@ from backend.data.execution import (
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEvent,
    exec_channel,
)

logger = logging.getLogger(__name__)
@@ -81,7 +82,7 @@ async def wait_for_execution(
    )

    event_bus = AsyncRedisExecutionEventBus()
    channel_key = f"{user_id}/{graph_id}/{execution_id}"
    channel_key = exec_channel(user_id, graph_id, execution_id)

    # Mutable container so _subscribe_and_wait can surface the task even if
    # asyncio.wait_for cancels the coroutine before it returns.

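The hunk above replaces the inlined channel f-string with a shared `exec_channel` helper. A plausible sketch of that helper, assuming it simply centralizes the old format string so publisher and subscriber can no longer drift apart:

```python
def exec_channel(user_id: str, graph_id: str, execution_id: str) -> str:
    # Single source of truth for the pubsub channel name; previously this
    # f-string was duplicated inline at each call site.
    return f"{user_id}/{graph_id}/{execution_id}"
```
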
@@ -949,24 +949,45 @@ class UserCredit(UserCreditBase):
                f"Top up amount must be at least 500 credits and multiple of 100 but is {amount}"
            )

        # Resolve the Stripe Product ID from LD; when unset (default), keep the
        # legacy inline product_data path (Stripe creates an ephemeral product
        # per Checkout). When set, reference the canonical Product so all
        # top-ups group under one entity in Stripe Dashboard reporting; the
        # amount stays dynamic via unit_amount.
        topup_product_id = await get_feature_flag_value(
            Flag.STRIPE_PRODUCT_ID_TOPUP.value, user_id, default=None
        )
        line_items: list[stripe.checkout.Session.CreateParamsLineItem] = (
            [
                {
                    "price_data": {
                        "currency": "usd",
                        "product": topup_product_id,
                        "unit_amount": amount,
                    },
                    "quantity": 1,
                }
            ]
            if isinstance(topup_product_id, str) and topup_product_id
            else [
                {
                    "price_data": {
                        "currency": "usd",
                        "product_data": {"name": "AutoGPT Platform Credits"},
                        "unit_amount": amount,
                    },
                    "quantity": 1,
                }
            ]
        )

        # Create checkout session
        # https://docs.stripe.com/checkout/quickstart?client=react
        # unit_amount param is always in the smallest currency unit (so cents for usd)
        # which is equal to amount of credits
        checkout_session = stripe.checkout.Session.create(
            customer=await get_stripe_customer_id(user_id),
            line_items=[
                {
                    "price_data": {
                        "currency": "usd",
                        "product_data": {
                            "name": "AutoGPT Platform Credits",
                        },
                        "unit_amount": amount,
                    },
                    "quantity": 1,
                }
            ],
            line_items=line_items,
            mode="payment",
            ui_mode="hosted",
            payment_intent_data={"setup_future_usage": "off_session"},
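The conditional `line_items` expression above reduces to a small either/or: reference an existing Product when the flag supplies an ID, otherwise fall back to an inline ephemeral product. Distilled into a standalone function (a sketch mirroring the diff's logic; Stripe's `price_data` takes either `product` or `product_data`, not both):

```python
def build_price_data(amount: int, topup_product_id) -> dict:
    """Pick the Checkout price_data shape based on the feature-flag value."""
    if isinstance(topup_product_id, str) and topup_product_id:
        # Canonical Product reference: all top-ups group under one entity.
        return {"currency": "usd", "product": topup_product_id, "unit_amount": amount}
    # Legacy path: Stripe creates an ephemeral product per Checkout session.
    return {
        "currency": "usd",
        "product_data": {"name": "AutoGPT Platform Credits"},
        "unit_amount": amount,
    }
```

Either way `unit_amount` stays dynamic, so the flag only changes how the product is identified, not what is charged.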
@@ -1442,6 +1463,7 @@ async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> i
# (move right) from downgrades (move left); ENTERPRISE is admin-managed and
# never reached via self-service flows.
_TIER_ORDER: tuple[SubscriptionTier, ...] = (
    SubscriptionTier.NO_TIER,
    SubscriptionTier.BASIC,
    SubscriptionTier.PRO,
    SubscriptionTier.MAX,
@@ -1479,6 +1501,30 @@ async def _get_active_subscription(customer_id: str) -> stripe.Subscription | No
    return None


async def get_active_subscription_period_end(user_id: str) -> int | None:
    """Return the Unix timestamp of the active sub's current_period_end, or None.

    Used to surface "next invoice on {date}" in upgrade dialog UX. Returns None
    for users without a Stripe customer or active sub. Stripe failures swallow
    to None — UX falls back to generic copy if the lookup misfires.
    """
    user = await get_user_by_id(user_id)
    if not user.stripe_customer_id:
        return None
    try:
        sub = await _get_active_subscription(user.stripe_customer_id)
    except stripe.StripeError:
        logger.warning(
            "get_active_subscription_period_end: Stripe lookup failed for user %s",
            user_id,
        )
        return None
    if sub is None:
        return None
    period_end = sub.current_period_end
    return int(period_end) if period_end else None


# Substrings Stripe uses in InvalidRequestError messages when the schedule is
# already in a terminal state (released / completed / canceled) and therefore
# cannot be released again. We only swallow the error when one of these appears;
@@ -1670,7 +1716,7 @@ async def modify_stripe_subscription_for_tier(
    user = await get_user_by_id(user_id)
    if not user.stripe_customer_id:
        return False
    current_tier = user.subscription_tier or SubscriptionTier.BASIC
    current_tier = user.subscription_tier or SubscriptionTier.NO_TIER

    sub = await _get_active_subscription(user.stripe_customer_id)
    if sub is None:
@@ -1891,7 +1937,7 @@ async def get_pending_subscription_change(
        return None
    effective_at = datetime.fromtimestamp(period_end, tz=timezone.utc)
    if sub.cancel_at_period_end:
        return SubscriptionTier.BASIC, effective_at
        return SubscriptionTier.NO_TIER, effective_at
    if not sub.schedule:
        return None
    schedule_id = sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id
@@ -1986,6 +2032,7 @@ async def create_subscription_checkout(
        success_url=success_url,
        cancel_url=cancel_url,
        subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
        allow_promotion_codes=True,
    )
    if not session.url:
        # An empty checkout URL for a paid upgrade is always an error; surfacing it
@@ -2071,7 +2118,7 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
    # ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has
    # a self-service Stripe sub, it's a data-consistency issue for an operator,
    # not something the webhook should automatically "fix".
    current_tier = user.subscriptionTier or SubscriptionTier.BASIC
    current_tier = user.subscriptionTier or SubscriptionTier.NO_TIER
    if current_tier == SubscriptionTier.ENTERPRISE:
        logger.warning(
            "sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier"
@@ -2170,7 +2217,7 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
            current_tier.value,
        )
        return
    tier = SubscriptionTier.BASIC
    tier = SubscriptionTier.NO_TIER
    # Idempotency: Stripe retries webhooks on delivery failure, and several event
    # types map to the same final tier. Skip the DB write + cache invalidation
    # when the tier is already correct to avoid redundant writes on replay.
@@ -2269,7 +2316,7 @@ async def handle_subscription_payment_failure(invoice: dict) -> None:
        )
        return

    current_tier = user.subscriptionTier or SubscriptionTier.BASIC
    current_tier = user.subscriptionTier or SubscriptionTier.NO_TIER
    if current_tier == SubscriptionTier.ENTERPRISE:
        logger.warning(
            "handle_subscription_payment_failure: skipping ENTERPRISE user %s"
@@ -2310,12 +2357,19 @@ async def handle_subscription_payment_failure(invoice: dict) -> None:
                }
            ),
        )
        # Balance covered the invoice. Pay the Stripe invoice so Stripe's dunning
        # system stops retrying it — without this call Stripe would retry automatically
        # and re-trigger this webhook, causing double-deductions each retry cycle.
        # Balance covered the invoice. Pay the Stripe invoice with
        # ``paid_out_of_band=True`` so Stripe marks the invoice paid without
        # retrying the card charge — the card already failed and the user is
        # paying via their AutoGPT balance, so a card retry here would
        # double-bill the user (card charge + balance debit). Stripe still
        # fires ``invoice.payment_succeeded`` on the transition; the success
        # handler reads ``paid_out_of_band`` and skips the credit grant so
        # the balance debit isn't reversed.
        if invoice_id:
            try:
                await run_in_threadpool(stripe.Invoice.pay, invoice_id)
                await run_in_threadpool(
                    stripe.Invoice.pay, invoice_id, paid_out_of_band=True
                )
            except stripe.StripeError:
                logger.warning(
                    "handle_subscription_payment_failure: balance deducted for user"
@@ -2355,7 +2409,95 @@ async def handle_subscription_payment_failure(invoice: dict) -> None:
            customer_id,
        )
        return
    await set_subscription_tier(user.id, SubscriptionTier.BASIC)
    await set_subscription_tier(user.id, SubscriptionTier.NO_TIER)


async def handle_subscription_payment_success(invoice: dict) -> None:
    """Grant AutoGPT credits equal to the paid Stripe invoice amount.

    Fires on every paid subscription invoice (initial signup, monthly renewal,
    and prorated upgrade charges). Credits = ``invoice.amount_paid`` cents,
    keyed by ``invoice_id`` for idempotency so Stripe retries don't double-grant.

    Skipped:
    - Non-subscription invoices (no ``subscription`` field).
    - Zero-amount invoices (e.g. card-validation checks, $0 trials).
    - ENTERPRISE users (admin-managed; they don't pay via self-service).
    """
    customer_id = invoice.get("customer")
    if not customer_id:
        logger.warning(
            "handle_subscription_payment_success: missing customer in invoice; skipping"
        )
        return
    sub_id: str = invoice.get("subscription") or ""
    if not sub_id:
        # Non-subscription invoices (one-off invoices, etc.) — no credit grant.
        return
    user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
    if not user:
        logger.warning(
            "handle_subscription_payment_success: no user for customer %s",
            customer_id,
        )
        return
    if (
        user.subscriptionTier or SubscriptionTier.NO_TIER
    ) == SubscriptionTier.ENTERPRISE:
        logger.warning(
            "handle_subscription_payment_success: skipping ENTERPRISE user %s"
            " (customer %s) — tier is admin-managed",
            user.id,
            customer_id,
        )
        return

    amount_paid: int = invoice.get("amount_paid", 0)
    invoice_id: str = invoice.get("id", "")
    if amount_paid <= 0 or not invoice_id:
        return

    # Skip when ``handle_subscription_payment_failure`` already covered this
    # invoice from the user's balance and marked it paid out of band — the
    # balance was debited there, granting matching credits here would reverse
    # the debit and give the user a free billing period.
    if invoice.get("paid_out_of_band"):
        logger.info(
            "handle_subscription_payment_success: skipping invoice %s for user %s"
            " (paid_out_of_band — covered by balance in failure handler)",
            invoice_id,
            user.id,
        )
        return

    try:
        await UserCredit()._add_transaction(
            user_id=user.id,
            amount=amount_paid,
            transaction_type=CreditTransactionType.GRANT,
            transaction_key=f"INVOICE-{invoice_id}",
            metadata=SafeJson(
                {
                    "stripe_customer_id": customer_id,
                    "stripe_subscription_id": sub_id,
                    "stripe_invoice_id": invoice_id,
                    "billing_reason": invoice.get("billing_reason", ""),
                    "reason": "subscription_invoice_paid",
                }
            ),
        )
        logger.info(
            "handle_subscription_payment_success: granted %d credits to user %s"
            " for invoice %s (sub %s)",
            amount_paid,
            user.id,
            invoice_id,
            sub_id,
        )
    except UniqueViolationError:
        # Idempotency key collision — Stripe retried this invoice's webhook and
        # we already granted the credits. Safe to ignore.
        return


async def admin_get_user_history(

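The failure/success handler pair above enforces one invariant: an invoice settled from the user's balance must not also trigger a credit grant. A toy simulation of that hand-off (function names and invoice fields here are illustrative, not the real backend API):

```python
def on_payment_failure(invoice: dict, balance: int) -> int:
    # Card charge failed: settle the invoice from the user's balance and
    # mark it so the follow-up success webhook won't re-grant credits.
    balance -= invoice["amount_due"]
    invoice["paid_out_of_band"] = True
    return balance


def on_payment_success(invoice: dict, balance: int) -> int:
    if invoice.get("paid_out_of_band"):
        return balance  # already settled from balance; granting would reverse it
    return balance + invoice["amount_paid"]  # normal renewal: grant credits


inv = {"amount_due": 500, "amount_paid": 500}
bal = on_payment_failure(inv, 1000)   # balance debited to 500
bal = on_payment_success(inv, bal)    # no grant: still 500
```

Without the flag check, the Stripe-emitted `invoice.payment_succeeded` that follows `Invoice.pay` would credit back exactly what the failure handler debited.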
@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import stripe
from prisma.enums import SubscriptionTier
from prisma.errors import UniqueViolationError
from prisma.models import User

from backend.data.credit import (
@@ -15,6 +16,7 @@ from backend.data.credit import (
    get_pending_subscription_change,
    get_proration_credit_cents,
    handle_subscription_payment_failure,
    handle_subscription_payment_success,
    is_tier_downgrade,
    is_tier_upgrade,
    modify_stripe_subscription_for_tier,
@@ -174,7 +176,7 @@ async def test_sync_subscription_from_stripe_enterprise_not_overwritten():

@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_cancelled():
    """When the only active sub is cancelled, the user is downgraded to BASIC."""
    """When the only active sub is cancelled, the user is downgraded to NO_TIER."""
    mock_user = _make_user(tier=SubscriptionTier.PRO)
    stripe_sub = {
        "id": "sub_old",
@@ -199,7 +201,7 @@ async def test_sync_subscription_from_stripe_cancelled():
        ) as mock_set,
    ):
        await sync_subscription_from_stripe(stripe_sub)
        mock_set.assert_awaited_once_with("user-1", SubscriptionTier.BASIC)
        mock_set.assert_awaited_once_with("user-1", SubscriptionTier.NO_TIER)


@pytest.mark.asyncio
@@ -1284,7 +1286,10 @@ async def test_sync_subscription_from_stripe_no_metadata_user_id_skips_check():

@pytest.mark.asyncio
async def test_handle_subscription_payment_failure_balance_covers_pays_invoice():
    """When balance covers the invoice, Stripe Invoice.pay is called to stop retries."""
    """When balance covers the invoice, Stripe Invoice.pay is called with
    paid_out_of_band=True so the card isn't double-charged on top of the
    balance debit (the card already failed; retrying it would let the
    success-handler webhook reverse the debit via the credit grant)."""
    mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
    invoice = {
        "id": "in_abc123",
@@ -1305,7 +1310,7 @@ async def test_handle_subscription_payment_failure_balance_covers_pays_invoice()
        patch("backend.data.credit.stripe.Invoice.pay") as mock_pay,
    ):
        await handle_subscription_payment_failure(invoice)
        mock_pay.assert_called_once_with("in_abc123")
        mock_pay.assert_called_once_with("in_abc123", paid_out_of_band=True)


@pytest.mark.asyncio
@@ -1367,6 +1372,356 @@ async def test_handle_subscription_payment_failure_passes_invoice_id_as_transact
    assert kwargs.get("transaction_key") == "in_idempotency_test"


@pytest.mark.asyncio
async def test_handle_subscription_payment_success_grants_credits():
    """A paid subscription invoice grants credits = amount_paid, keyed by invoice_id."""
    mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
    invoice = {
        "id": "in_abc123",
        "customer": "cus_123",
        "subscription": "sub_abc123",
        "amount_paid": 5000,
        "billing_reason": "subscription_cycle",
    }

    add_tx_mock = AsyncMock()
    with (
        patch(
            "backend.data.credit.User.prisma",
            return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
        ),
        patch(
            "backend.data.credit.UserCredit._add_transaction",
            new=add_tx_mock,
        ),
    ):
        await handle_subscription_payment_success(invoice)

    add_tx_mock.assert_awaited_once()
    kwargs = add_tx_mock.await_args.kwargs
    assert kwargs["amount"] == 5000
    assert kwargs["transaction_key"] == "INVOICE-in_abc123"


@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_non_subscription_invoice():
    """Invoices with no subscription field (one-off invoices) are no-ops."""
    invoice = {
        "id": "in_abc123",
        "customer": "cus_123",
        "amount_paid": 5000,
        # No 'subscription' field
    }

    prisma_mock = MagicMock()
    with patch("backend.data.credit.User.prisma", return_value=prisma_mock):
        await handle_subscription_payment_success(invoice)
    prisma_mock.find_first.assert_not_called()


@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_paid_out_of_band():
    """When the failure handler covered the invoice from the user's balance and
    marked it ``paid_out_of_band=True``, the success-handler webhook that
    follows must NOT grant credits — doing so would reverse the balance debit
    and effectively give the user a free billing period."""
    mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
    invoice = {
        "id": "in_oob_123",
        "customer": "cus_123",
        "subscription": "sub_abc123",
        "amount_paid": 5000,
        "billing_reason": "subscription_cycle",
        "paid_out_of_band": True,
    }

    add_tx_mock = AsyncMock()
    with (
        patch(
            "backend.data.credit.User.prisma",
            return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
        ),
        patch(
            "backend.data.credit.UserCredit._add_transaction",
            new=add_tx_mock,
        ),
    ):
        await handle_subscription_payment_success(invoice)
    add_tx_mock.assert_not_called()


@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_zero_amount():
    """Zero-amount invoices (card validation, $0 trials) are no-ops."""
    mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
    invoice = {
        "id": "in_abc123",
        "customer": "cus_123",
        "subscription": "sub_abc123",
        "amount_paid": 0,
    }

    add_tx_mock = AsyncMock()
    with (
        patch(
            "backend.data.credit.User.prisma",
            return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
        ),
        patch(
            "backend.data.credit.UserCredit._add_transaction",
            new=add_tx_mock,
        ),
    ):
        await handle_subscription_payment_success(invoice)
    add_tx_mock.assert_not_called()


@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_missing_customer():
    """Invoices missing the customer field are dropped with a warning."""
    invoice = {
        "id": "in_abc",
        "subscription": "sub_abc",
        "amount_paid": 1000,
    }
    prisma_mock = MagicMock()
    with patch("backend.data.credit.User.prisma", return_value=prisma_mock):
        await handle_subscription_payment_success(invoice)
    prisma_mock.find_first.assert_not_called()


@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_unknown_user():
    """Invoices for an unknown stripeCustomerId are dropped with a warning."""
    invoice = {
        "id": "in_abc",
        "customer": "cus_unknown",
        "subscription": "sub_abc",
        "amount_paid": 1000,
    }
    add_tx_mock = AsyncMock()
    with (
        patch(
            "backend.data.credit.User.prisma",
            return_value=MagicMock(find_first=AsyncMock(return_value=None)),
        ),
        patch(
            "backend.data.credit.UserCredit._add_transaction",
            new=add_tx_mock,
        ),
    ):
        await handle_subscription_payment_success(invoice)
    add_tx_mock.assert_not_called()


@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_enterprise():
    """ENTERPRISE users don't get credit grants from Stripe invoices."""
    mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.ENTERPRISE)
    invoice = {
        "id": "in_abc",
        "customer": "cus_123",
        "subscription": "sub_abc",
        "amount_paid": 5000,
    }
    add_tx_mock = AsyncMock()
    with (
        patch(
            "backend.data.credit.User.prisma",
            return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
        ),
        patch(
            "backend.data.credit.UserCredit._add_transaction",
            new=add_tx_mock,
        ),
    ):
        await handle_subscription_payment_success(invoice)
    add_tx_mock.assert_not_called()


@pytest.mark.asyncio
async def test_handle_subscription_payment_success_idempotent_on_unique_violation():
    """If the GRANT transaction key already exists (Stripe webhook retry),
    UniqueViolationError is swallowed so the webhook returns 200 and Stripe
    stops retrying."""
    mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
    invoice = {
        "id": "in_abc",
        "customer": "cus_123",
        "subscription": "sub_abc",
        "amount_paid": 5000,
    }
    add_tx_mock = AsyncMock(side_effect=UniqueViolationError({"error": "dup"}))
    with (
        patch(
            "backend.data.credit.User.prisma",
            return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
        ),
        patch(
            "backend.data.credit.UserCredit._add_transaction",
            new=add_tx_mock,
        ),
    ):
        await handle_subscription_payment_success(invoice)
    add_tx_mock.assert_awaited_once()

@pytest.mark.asyncio
async def test_get_active_subscription_period_end_returns_unix_timestamp():
    """Happy path: returns int(current_period_end) for an active sub."""
    mock_sub = stripe.Subscription.construct_from(
        {"id": "sub_abc", "current_period_end": 1779340148}, "k"
    )
    mock_list = MagicMock()
    mock_list.data = [mock_sub]
    user = MagicMock(spec=User)
    user.stripe_customer_id = "cus_abc"

    with (
        patch(
            "backend.data.credit.get_user_by_id",
            new_callable=AsyncMock,
            return_value=user,
        ),
        patch(
            "backend.data.credit.stripe.Subscription.list_async",
            new_callable=AsyncMock,
            return_value=mock_list,
        ),
    ):
        from backend.data.credit import get_active_subscription_period_end

        result = await get_active_subscription_period_end("user-1")
    assert result == 1779340148


@pytest.mark.asyncio
async def test_get_active_subscription_period_end_returns_none_without_customer():
    """Users without a Stripe customer ID return None — no Stripe API call."""
    user = MagicMock(spec=User)
    user.stripe_customer_id = None
    list_mock = AsyncMock()

    with (
        patch(
            "backend.data.credit.get_user_by_id",
            new_callable=AsyncMock,
            return_value=user,
        ),
        patch(
            "backend.data.credit.stripe.Subscription.list_async",
            new=list_mock,
        ),
    ):
        from backend.data.credit import get_active_subscription_period_end

        result = await get_active_subscription_period_end("user-1")
    assert result is None
    list_mock.assert_not_called()


@pytest.mark.asyncio
async def test_get_active_subscription_period_end_swallows_stripe_errors():
    """A Stripe error during the lookup returns None instead of raising."""
    user = MagicMock(spec=User)
    user.stripe_customer_id = "cus_abc"

    with (
        patch(
            "backend.data.credit.get_user_by_id",
            new_callable=AsyncMock,
            return_value=user,
        ),
        patch(
            "backend.data.credit.stripe.Subscription.list_async",
            side_effect=stripe.StripeError("boom"),
        ),
    ):
        from backend.data.credit import get_active_subscription_period_end

        result = await get_active_subscription_period_end("user-1")
    assert result is None

@pytest.mark.asyncio
async def test_top_up_intent_uses_inline_product_data_when_flag_unset():
    """When STRIPE_PRODUCT_ID_TOPUP flag is undefined (default), top-up Checkout
    creates an ephemeral product per session via product_data."""
    from backend.data.credit import UserCredit

    mock_session = MagicMock()
    mock_session.id = "cs_test_topup"
    mock_session.url = "https://checkout.stripe.com/c/cs_test_topup"
    create_mock = MagicMock(return_value=mock_session)
    credit_system = UserCredit()
    with (
        patch(
            "backend.data.credit.get_stripe_customer_id",
            new_callable=AsyncMock,
            return_value="cus_123",
        ),
        patch(
            "backend.data.credit.get_feature_flag_value",
            new_callable=AsyncMock,
            return_value=None,
        ),
        patch(
            "backend.data.credit.stripe.checkout.Session.create",
            new=create_mock,
        ),
        patch.object(credit_system, "_add_transaction", new_callable=AsyncMock),
    ):
        await credit_system.top_up_intent(user_id="user-1", amount=500)

    price_data = create_mock.call_args.kwargs["line_items"][0]["price_data"]
    assert price_data == {
        "currency": "usd",
        "unit_amount": 500,
        "product_data": {"name": "AutoGPT Platform Credits"},
    }


@pytest.mark.asyncio
async def test_top_up_intent_references_product_id_when_flag_set():
    """When STRIPE_PRODUCT_ID_TOPUP flag returns a string, top-up Checkout
    references the canonical Product ID and keeps the per-session amount via
    unit_amount."""
    from backend.data.credit import UserCredit

    mock_session = MagicMock()
    mock_session.id = "cs_test_topup"
    mock_session.url = "https://checkout.stripe.com/c/cs_test_topup"
    create_mock = MagicMock(return_value=mock_session)
    credit_system = UserCredit()
    with (
        patch(
            "backend.data.credit.get_stripe_customer_id",
            new_callable=AsyncMock,
            return_value="cus_123",
        ),
        patch(
            "backend.data.credit.get_feature_flag_value",
            new_callable=AsyncMock,
            return_value="prod_abc123",
        ),
        patch(
            "backend.data.credit.stripe.checkout.Session.create",
            new=create_mock,
        ),
        patch.object(credit_system, "_add_transaction", new_callable=AsyncMock),
    ):
        await credit_system.top_up_intent(user_id="user-1", amount=2500)

    price_data = create_mock.call_args.kwargs["line_items"][0]["price_data"]
    assert price_data == {
        "currency": "usd",
        "unit_amount": 2500,
        "product": "prod_abc123",
    }
    # No product_data — that path is mutually exclusive with product reference.
    assert "product_data" not in price_data

@pytest.mark.asyncio
async def test_modify_stripe_subscription_for_tier_modifies_existing_sub():
    """modify_stripe_subscription_for_tier calls Subscription.modify and returns True."""
@@ -1845,7 +2200,7 @@ async def test_release_pending_subscription_schedule_no_stripe_customer_returns_

@pytest.mark.asyncio
async def test_get_pending_subscription_change_cancel_at_period_end():
    """cancel_at_period_end=True maps to pending BASIC at current_period_end."""
    """cancel_at_period_end=True maps to pending NO_TIER at current_period_end."""
    import time as time_mod

    get_pending_subscription_change.cache_clear()  # type: ignore[attr-defined]
@@ -1894,7 +2249,7 @@ async def test_get_pending_subscription_change_cancel_at_period_end():

    assert result is not None
    pending_tier, effective_at = result
    assert pending_tier == SubscriptionTier.BASIC
    assert pending_tier == SubscriptionTier.NO_TIER
    assert int(effective_at.timestamp()) == period_end

@@ -98,6 +98,12 @@ from backend.data.notifications import (
 )
 from backend.data.onboarding import increment_onboarding_runs
 from backend.data.platform_cost import log_platform_cost
+from backend.data.push_subscription import (
+    cleanup_failed_subscriptions,
+    delete_push_subscription,
+    get_user_push_subscriptions,
+    increment_fail_count,
+)
 from backend.data.understanding import (
     get_business_understanding,
     upsert_business_understanding,
@@ -339,6 +345,16 @@ class DatabaseManager(AppService):
     # ============ Platform Cost Tracking ============ #
     log_platform_cost = _(log_platform_cost)
 
+    # ============ Push Notifications ============ #
+    get_user_push_subscriptions = _(get_user_push_subscriptions)
+    delete_push_subscription = _(delete_push_subscription)
+    increment_push_fail_count = _(
+        increment_fail_count, name="increment_push_fail_count"
+    )
+    cleanup_failed_push_subscriptions = _(
+        cleanup_failed_subscriptions, name="cleanup_failed_push_subscriptions"
+    )
+
     # ============ Platform Linking ============ #
     find_server_link_owner = _(platform_linking_db.find_server_link_owner)
     find_user_link_owner = _(platform_linking_db.find_user_link_owner)
@@ -557,6 +573,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     # ============ Platform Cost Tracking ============ #
     log_platform_cost = d.log_platform_cost
 
+    # ============ Push Notifications ============ #
+    get_user_push_subscriptions = d.get_user_push_subscriptions
+    delete_push_subscription = d.delete_push_subscription
+    increment_push_fail_count = d.increment_push_fail_count
+    cleanup_failed_push_subscriptions = d.cleanup_failed_push_subscriptions
+
     # ============ Platform Linking ============ #
     find_server_link_owner = d.find_server_link_owner
     find_user_link_owner = d.find_user_link_owner
498
autogpt_platform/backend/backend/data/e2e_redis_rabbit_test.py
Normal file
@@ -0,0 +1,498 @@
"""End-to-end coverage of the data-layer APIs over the live 3-shard Redis
cluster + RabbitMQ broker. Tests skip when their infra is unreachable.
Container-restart scenarios live in `e2e_redis_restart_test.py`."""

from __future__ import annotations

import asyncio
import json
import time
from datetime import datetime, timezone
from uuid import uuid4

import pytest

import backend.data.redis_client as redis_client
from backend.api.model import NotificationPayload
from backend.data.execution import (
    AsyncRedisExecutionEventBus,
    ExecutionStatus,
    NodeExecutionEvent,
    exec_channel,
    graph_all_channel,
)
from backend.data.notification_bus import (
    AsyncRedisNotificationEventBus,
    NotificationEvent,
)
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.executor.utils import (
    GRAPH_EXECUTION_EXCHANGE,
    GRAPH_EXECUTION_QUEUE_NAME,
    create_execution_queue_config,
)


def _has_live_cluster() -> bool:
    try:
        c = redis_client.connect()
    except Exception:  # noqa: BLE001 — any connect failure → skip
        return False
    try:
        c.close()
    except Exception:
        pass
    return True


def _has_live_rabbit() -> bool:
    """Probe the rabbitmq host:port from settings; skip if unreachable."""
    import socket

    from backend.util.settings import Settings

    s = Settings()
    try:
        with socket.create_connection(
            (s.config.rabbitmq_host, s.config.rabbitmq_port), timeout=1.0
        ):
            return True
    except Exception:  # noqa: BLE001 - any connect failure → skip the test
        return False


cluster_only = pytest.mark.skipif(
    not _has_live_cluster(),
    reason="local redis cluster not reachable; skip e2e integration",
)
rabbit_only = pytest.mark.skipif(
    not _has_live_rabbit(),
    reason="local rabbitmq not reachable; skip e2e integration",
)


def _make_node_event(*, user_id: str, graph_id: str, gex_id: str, marker: str):
    return NodeExecutionEvent(
        user_id=user_id,
        graph_id=graph_id,
        graph_version=1,
        graph_exec_id=gex_id,
        node_exec_id=f"node-exec-{marker}",
        node_id="node-1",
        block_id="block-1",
        status=ExecutionStatus.COMPLETED,
        input_data={"in": marker},
        output_data={"out": [marker]},
        add_time=datetime.now(tz=timezone.utc),
        queue_time=None,
        start_time=datetime.now(tz=timezone.utc),
        end_time=datetime.now(tz=timezone.utc),
    )
# ---------- Scenario 1: cluster cache round-trip across slots ----------


@cluster_only
def test_cluster_cache_roundtrip_across_three_slots() -> None:
    """A list-graphs-style cache flow: SET keys chosen to land on different
    shards, GET them back. Validates the basic cluster-routing contract
    end-to-end."""
    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()
    keys = []
    try:
        # Pick keys that hash to different slots — try until 3 distinct shards.
        seen: set[tuple[str, int]] = set()
        for i in range(2000):
            key = f"e2e:cache:{i}"
            node = cluster.get_node_from_key(key)
            owner = (node.host, node.port)
            if owner in seen:
                continue
            seen.add(owner)
            keys.append(key)
            if len(seen) >= 3:
                break
        assert len(keys) >= 3

        for i, k in enumerate(keys):
            cluster.setex(k, 60, f"v-{i}")
        for i, k in enumerate(keys):
            assert cluster.get(k) == f"v-{i}"
    finally:
        for k in keys:
            try:
                cluster.delete(k)
            except Exception:
                pass
        redis_client.disconnect()
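The shard-probing loop above works because Redis Cluster maps every key to one of 16384 slots via `CRC16(key) mod 16384`, where a non-empty `{...}` hash tag restricts hashing to the tagged substring. A pure-Python sketch of that mapping, per the Redis Cluster specification (`keyslot` and `crc16_xmodem` are illustrative names, not client APIs):

```python
def crc16_xmodem(data: bytes) -> int:
    """CRC16-CCITT (XMODEM), the polynomial Redis Cluster uses for key slots."""
    crc = 0
    for byte in data:
        crc ^= byte << 8
        for _ in range(8):
            if crc & 0x8000:
                crc = ((crc << 1) ^ 0x1021) & 0xFFFF
            else:
                crc = (crc << 1) & 0xFFFF
    return crc


def keyslot(key: str) -> int:
    """Map a key to one of 16384 cluster slots, honoring a non-empty {hash tag}."""
    start = key.find("{")
    if start != -1:
        end = key.find("}", start + 1)
        if end > start + 1:  # the tag must be non-empty to take effect
            key = key[start + 1 : end]
    return crc16_xmodem(key.encode()) % 16384
```

Keys sharing a tag land on the same slot (hence the same shard), which is why the hash-tagged channels used elsewhere in this file keep all of one session's SPUBLISH/SSUBSCRIBE traffic on a single node.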
# ---------- Scenarios 2 & 3: graph execution event streams ----------


@pytest.mark.asyncio
@cluster_only
async def test_graph_execution_events_complete_under_ten_seconds() -> None:
    """A listener subscribes to the per-exec channel; the producer publishes
    one node event. The listener must observe it in under 10 seconds —
    pins the latency contract end-to-end through SPUBLISH/SSUBSCRIBE."""
    redis_client._async_clients.clear()
    user_id = f"u-e2e-{uuid4().hex[:8]}"
    graph_id = f"g-{uuid4().hex[:8]}"
    gex_id = f"x-{uuid4().hex[:8]}"

    publisher = AsyncRedisExecutionEventBus()
    subscriber = AsyncRedisExecutionEventBus()
    received: list[str] = []

    async def _consume() -> None:
        async for evt in subscriber.listen_events(
            exec_channel(user_id, graph_id, gex_id)
        ):
            received.append(getattr(evt, "node_exec_id", "graph"))
            return

    task = asyncio.create_task(_consume())
    await asyncio.sleep(0.3)

    start = time.monotonic()
    try:
        await publisher.publish_event(
            _make_node_event(
                user_id=user_id, graph_id=graph_id, gex_id=gex_id, marker="m1"
            ),
            exec_channel(user_id, graph_id, gex_id),
        )
        await asyncio.wait_for(task, timeout=10.0)
    finally:
        if not task.done():
            task.cancel()
        await subscriber.close()
        await redis_client.disconnect_async()

    elapsed = time.monotonic() - start
    assert elapsed < 10.0, f"event roundtrip took {elapsed:.2f}s, expected < 10s"
    assert received == ["node-exec-m1"]
@pytest.mark.asyncio
@cluster_only
async def test_two_concurrent_graphs_no_cross_talk() -> None:
    """Two graphs execute in parallel; two listeners on different per-exec
    channels each receive only their own events."""
    redis_client._async_clients.clear()
    user_id = f"u-e2e-{uuid4().hex[:8]}"
    g1, g2 = f"g1-{uuid4().hex[:8]}", f"g2-{uuid4().hex[:8]}"
    e1, e2 = f"e1-{uuid4().hex[:8]}", f"e2-{uuid4().hex[:8]}"

    publisher = AsyncRedisExecutionEventBus()
    sub_a = AsyncRedisExecutionEventBus()
    sub_b = AsyncRedisExecutionEventBus()

    async def _listen_one(bus, channel_key: str, sink: list, want: int) -> None:
        async for evt in bus.listen_events(channel_key):
            sink.append(getattr(evt, "node_exec_id", "graph"))
            if len(sink) >= want:
                return

    sink_a: list[str] = []
    sink_b: list[str] = []
    t_a = asyncio.create_task(
        _listen_one(sub_a, exec_channel(user_id, g1, e1), sink_a, want=3)
    )
    t_b = asyncio.create_task(
        _listen_one(sub_b, exec_channel(user_id, g2, e2), sink_b, want=3)
    )
    await asyncio.sleep(0.3)

    try:
        for i in range(3):
            await publisher.publish_event(
                _make_node_event(
                    user_id=user_id, graph_id=g1, gex_id=e1, marker=f"a{i}"
                ),
                exec_channel(user_id, g1, e1),
            )
            await publisher.publish_event(
                _make_node_event(
                    user_id=user_id, graph_id=g2, gex_id=e2, marker=f"b{i}"
                ),
                exec_channel(user_id, g2, e2),
            )

        await asyncio.wait_for(asyncio.gather(t_a, t_b), timeout=10.0)
        assert sink_a == ["node-exec-a0", "node-exec-a1", "node-exec-a2"]
        assert sink_b == ["node-exec-b0", "node-exec-b1", "node-exec-b2"]
    finally:
        await sub_a.close()
        await sub_b.close()
        await redis_client.disconnect_async()
# ---------- Scenario 4: aggregate /all channel for graph executions ----------


@pytest.mark.asyncio
@cluster_only
async def test_three_executions_land_on_aggregate_channel() -> None:
    """Subscribe to the aggregate ``/all`` channel; trigger 3 different
    executions of the same graph; assert all 3 land on the aggregate."""
    redis_client._async_clients.clear()
    user_id = f"u-e2e-{uuid4().hex[:8]}"
    graph_id = f"g-{uuid4().hex[:8]}"
    exec_ids = [f"x{i}-{uuid4().hex[:6]}" for i in range(3)]

    publisher = AsyncRedisExecutionEventBus()
    subscriber = AsyncRedisExecutionEventBus()
    received: list[str] = []

    async def _listen_all() -> None:
        async for evt in subscriber.listen_events(graph_all_channel(user_id, graph_id)):
            received.append(getattr(evt, "graph_exec_id", "?"))
            if len(received) >= 3:
                return

    task = asyncio.create_task(_listen_all())
    await asyncio.sleep(0.3)

    try:
        for ex in exec_ids:
            await publisher.publish_event(
                _make_node_event(
                    user_id=user_id, graph_id=graph_id, gex_id=ex, marker=ex
                ),
                graph_all_channel(user_id, graph_id),
            )

        await asyncio.wait_for(task, timeout=10.0)
        # Order of receipt may vary slightly under load — check set membership.
        assert set(received) == set(exec_ids)
    finally:
        await subscriber.close()
        await redis_client.disconnect_async()
# ---------- Scenarios 5 & 6: copilot/notification per-user channels ----------


@pytest.mark.asyncio
@cluster_only
async def test_copilot_cancel_signal_via_sharded_pubsub() -> None:
    """A subscriber on a per-session channel receives an SPUBLISH cancel
    signal — the primitive the copilot executor uses for graceful cancel."""
    redis_client._async_clients.clear()
    session_id = f"sess-{uuid4().hex[:8]}"
    channel = "{copilot/" + session_id + "}/cancel"

    client = await redis_client.connect_sharded_pubsub_async(channel)
    pubsub = client.pubsub()
    received: list[str] = []
    try:
        await pubsub.execute_command("SSUBSCRIBE", channel)
        # Prime the channels map so listen() doesn't early-exit (see _Subscription).
        pubsub.channels[channel] = None  # type: ignore[index]

        async def _pump() -> None:
            async for msg in pubsub.listen():
                if msg.get("type") == "smessage":
                    received.append(msg["data"])
                    return

        listener = asyncio.create_task(_pump())
        await asyncio.sleep(0.2)

        cluster = await redis_client.get_redis_async()
        await cluster.execute_command("SPUBLISH", channel, "cancel")

        await asyncio.wait_for(listener, timeout=5.0)
        assert received == ["cancel"]
    finally:
        try:
            await pubsub.execute_command("SUNSUBSCRIBE", channel)
        except Exception:
            pass
        await pubsub.aclose()
        await client.aclose()
        await redis_client.disconnect_async()
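The cancel channel exercised above is, at its core, a flag checked between units of work. The same control flow can be shown in-process with a plain `asyncio.Event` standing in for the Redis channel (`run_until_cancel` is a hypothetical analogue for illustration, not the executor's real loop):

```python
import asyncio


async def run_until_cancel(work_step, cancel: asyncio.Event) -> int:
    """Run `work_step` repeatedly, checking a cancel flag between steps,
    the way the copilot executor checks its per-session cancel channel
    between tool calls. Returns how many steps completed."""
    steps_done = 0
    while not cancel.is_set():
        await work_step()
        steps_done += 1
    return steps_done
```

The Redis version differs only in transport: `cancel.is_set()` becomes "has an `smessage` arrived on the `{copilot/<session>}/cancel` channel", which the hash tag pins to a single shard.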
@pytest.mark.asyncio
@cluster_only
async def test_notification_fan_out_per_user_channel() -> None:
    """Per-user SSUBSCRIBE: a publish on the user's notification channel
    reaches the user's listener and only that listener."""
    redis_client._async_clients.clear()
    user_id = f"u-notif-{uuid4().hex[:8]}"
    other_user_id = f"u-other-{uuid4().hex[:8]}"

    publisher = AsyncRedisNotificationEventBus()
    listener_user = AsyncRedisNotificationEventBus()
    listener_other = AsyncRedisNotificationEventBus()

    user_received: list[str] = []
    other_received: list[str] = []
    notif_for_user = NotificationEvent(
        user_id=user_id,
        payload=NotificationPayload(type="info", event="balance-low"),
    )
    notif_for_other = NotificationEvent(
        user_id=other_user_id,
        payload=NotificationPayload(type="info", event="other"),
    )

    async def _listen_one(bus: AsyncRedisNotificationEventBus, uid: str, sink: list):
        async for evt in bus.listen(uid):
            sink.append(evt.user_id)
            return

    t_user = asyncio.create_task(_listen_one(listener_user, user_id, user_received))
    t_other = asyncio.create_task(
        _listen_one(listener_other, other_user_id, other_received)
    )
    await asyncio.sleep(0.3)

    try:
        await publisher.publish(notif_for_user)
        await publisher.publish(notif_for_other)
        await asyncio.wait_for(asyncio.gather(t_user, t_other), timeout=10.0)
        assert user_received == [user_id]
        assert other_received == [other_user_id]
    finally:
        await listener_user.close()
        await listener_other.close()
        await publisher.close()
        await redis_client.disconnect_async()
# ---------- Scenario 7: idle WS connection 60s ----------


@pytest.mark.asyncio
@cluster_only
async def test_idle_subscriber_60s_then_receives_publish() -> None:
    """An SSUBSCRIBE that sits idle past one health-check interval must
    still deliver a subsequent SPUBLISH (uses HEALTH_CHECK_INTERVAL+5s)."""
    redis_client._async_clients.clear()
    channel = "{idle-e2e}/exec/" + uuid4().hex[:8]
    client = await redis_client.connect_sharded_pubsub_async(channel)
    pubsub = client.pubsub()
    try:
        await pubsub.execute_command("SSUBSCRIBE", channel)
        pubsub.channels[channel] = None  # type: ignore[index]
        # Drain the ssubscribe confirmation.
        async for _msg in pubsub.listen():
            break

        idle_seconds = redis_client.HEALTH_CHECK_INTERVAL + 5
        await asyncio.sleep(idle_seconds)

        cluster = await redis_client.get_redis_async()
        await cluster.execute_command("SPUBLISH", channel, "hello-after-idle")

        async for msg in pubsub.listen():
            if msg.get("type") == "smessage":
                assert msg["data"] == "hello-after-idle"
                return
    finally:
        try:
            await pubsub.execute_command("SUNSUBSCRIBE", channel)
        except Exception:
            pass
        await pubsub.aclose()
        await client.aclose()
        await redis_client.disconnect_async()
# ---------- Scenario 8: graph_execution_queue_v2 publish + consume ----------


@pytest.mark.asyncio
@rabbit_only
async def test_graph_execution_queue_publish_and_consume() -> None:
    """End-to-end on a test-scoped quorum queue: publish via AsyncRabbitMQ
    → consume → payload round-trips intact. Uses a unique routing key so
    the live executor consumer (if any) doesn't race for the message."""
    from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig

    test_queue_name = f"e2e_test_{uuid4().hex[:8]}_v2"
    test_routing_key = f"e2e.test.{uuid4().hex[:8]}"
    test_exchange = Exchange(
        name=GRAPH_EXECUTION_EXCHANGE.name,
        type=ExchangeType.DIRECT,
        durable=True,
    )
    test_queue = Queue(
        name=test_queue_name,
        durable=True,
        # Quorum queues reject auto_delete; we delete the queue explicitly
        # in the finally block instead.
        auto_delete=False,
        exchange=test_exchange,
        routing_key=test_routing_key,
        arguments={"x-queue-type": "quorum"},
    )
    cfg = RabbitMQConfig(vhost="/", exchanges=[test_exchange], queues=[test_queue])

    publisher = AsyncRabbitMQ(cfg)
    await publisher.connect()
    consumer = AsyncRabbitMQ(cfg)
    await consumer.connect()

    payload = json.dumps(
        {"graph_exec_id": f"e2e-{uuid4().hex[:8]}", "marker": "round-trip"}
    )

    try:
        channel = await consumer.get_channel()
        queue_obj = await channel.get_queue(test_queue_name)

        await publisher.publish_message(
            routing_key=test_routing_key,
            message=payload,
            exchange=test_exchange,
        )

        # Poll get() — the quorum queue must surface the publish within 5s.
        deadline = time.monotonic() + 5.0
        msg = None
        while time.monotonic() < deadline:
            msg = await queue_obj.get(no_ack=True, fail=False)
            if msg is not None:
                break
            await asyncio.sleep(0.05)
        assert msg is not None, "publish never reached the quorum queue"
        assert msg.body.decode() == payload
    finally:
        # Explicit cleanup — quorum queues can't auto_delete (see above).
        try:
            channel = await consumer.get_channel()
            await channel.queue_delete(test_queue_name, if_unused=False, if_empty=False)
        except Exception:
            pass
        await publisher.disconnect()
        await consumer.disconnect()
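The bounded `get()` loop above is a general pattern worth naming: poll with a deadline instead of sleeping a fixed worst-case interval, so the happy path returns as soon as the message lands. A synchronous sketch of the same shape (`poll_until` is a hypothetical helper, not part of the codebase):

```python
import time


def poll_until(predicate, timeout_s: float = 5.0, interval_s: float = 0.05):
    """Call `predicate` until it returns a truthy value or the deadline
    passes. Returns the truthy value, or None on timeout. This is the same
    bounded-poll shape the queue test uses around `queue_obj.get(...)`."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        result = predicate()
        if result:
            return result
        time.sleep(interval_s)
    return None
```

Using `time.monotonic()` rather than `time.time()` keeps the deadline immune to wall-clock adjustments, which matters in long-running CI jobs.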
@pytest.mark.asyncio
@rabbit_only
async def test_graph_execution_queue_uses_quorum_via_real_broker() -> None:
    """Live-broker check that `graph_execution_queue_v2` is declared as
    quorum — passive re-declare with `x-queue-type=quorum` must not raise."""
    cfg = create_execution_queue_config()
    client = AsyncRabbitMQ(cfg)
    await client.connect()  # declares everything in cfg
    try:
        channel = await client.get_channel()
        # Re-declare passively — must not raise PRECONDITION_FAILED while the
        # type matches; it would raise if the quorum declaration were lost.
        q = await channel.declare_queue(
            name=GRAPH_EXECUTION_QUEUE_NAME,
            durable=True,
            arguments={"x-queue-type": "quorum"},
            passive=True,
        )
        assert q.name == GRAPH_EXECUTION_QUEUE_NAME
    finally:
        await client.disconnect()
313
autogpt_platform/backend/backend/data/e2e_redis_restart_test.py
Normal file
@@ -0,0 +1,313 @@
"""Sharded pubsub reconnect across a real `docker restart` of a shard,
against a private 3-shard cluster on isolated host ports. Gated on
`E2E_REDIS_CLUSTER_RESTART=1` + `docker` on PATH, marked `pytest.mark.slow`."""

from __future__ import annotations

import asyncio
import importlib
import os
import shutil
import socket
import subprocess
import time
from uuid import uuid4

import pytest

# Disjoint from the dev-compose ports (17000-17002) so both stacks coexist.
ISOLATED_PROJECT = "redis-restart-test"
ISOLATED_PORTS = (27110, 27111, 27112)
ISOLATED_BUS_PORTS = (37110, 37111, 37112)
def _docker_available() -> bool:
    return shutil.which("docker") is not None


def _isolated_enabled() -> bool:
    return os.getenv("E2E_REDIS_CLUSTER_RESTART", "").lower() in ("1", "true", "yes")


cluster_restart_only = pytest.mark.skipif(
    not (_docker_available() and _isolated_enabled()),
    reason=(
        "isolated docker cluster restart e2e: requires docker + E2E_REDIS_CLUSTER_RESTART=1"
    ),
)


def _run(cmd: list[str], *, timeout: float = 60.0) -> subprocess.CompletedProcess[str]:
    return subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        timeout=timeout,
        check=False,
    )


def _wait_port(port: int, *, deadline_s: float = 60.0) -> None:
    deadline = time.monotonic() + deadline_s
    while time.monotonic() < deadline:
        try:
            with socket.create_connection(("127.0.0.1", port), timeout=1.0):
                return
        except OSError:
            time.sleep(0.5)
    raise TimeoutError(f"port {port} never opened within {deadline_s:.0f}s")
def _start_isolated_cluster() -> None:
    """Spin up a private 3-shard cluster via raw `docker run` + one-shot
    `redis-cli --cluster create`."""
    network = f"{ISOLATED_PROJECT}-net"
    _run(["docker", "network", "create", network])  # may exist; ignore exit
    for i, (port, bus) in enumerate(zip(ISOLATED_PORTS, ISOLATED_BUS_PORTS)):
        name = f"{ISOLATED_PROJECT}-redis-{i}"
        _run(["docker", "rm", "-f", name])
        rc = _run(
            [
                "docker",
                "run",
                "-d",
                "--name",
                name,
                "--network",
                network,
                "--network-alias",
                f"redis-{i}",
                "-p",
                f"{port}:{port}",
                "redis:7",
                "redis-server",
                "--port",
                str(port),
                "--cluster-enabled",
                "yes",
                "--cluster-config-file",
                "nodes.conf",
                "--cluster-node-timeout",
                "5000",
                "--cluster-require-full-coverage",
                "no",
                "--cluster-announce-hostname",
                f"redis-{i}",
                "--cluster-announce-port",
                str(port),
                "--cluster-announce-bus-port",
                str(bus),
                "--cluster-preferred-endpoint-type",
                "hostname",
            ]
        )
        if rc.returncode != 0:
            raise RuntimeError(f"docker run redis-{i} failed: {rc.stderr}")
    for port in ISOLATED_PORTS:
        _wait_port(port)
    rc = _run(
        [
            "docker",
            "run",
            "--rm",
            "--network",
            network,
            "redis:7",
            "redis-cli",
            "--cluster",
            "create",
            f"redis-0:{ISOLATED_PORTS[0]}",
            f"redis-1:{ISOLATED_PORTS[1]}",
            f"redis-2:{ISOLATED_PORTS[2]}",
            "--cluster-replicas",
            "0",
            "--cluster-yes",
        ]
    )
    if rc.returncode != 0:
        raise RuntimeError(f"cluster create failed: {rc.stderr}")
    deadline = time.monotonic() + 30
    while time.monotonic() < deadline:
        info = _run(
            [
                "docker",
                "exec",
                f"{ISOLATED_PROJECT}-redis-0",
                "redis-cli",
                "-p",
                str(ISOLATED_PORTS[0]),
                "cluster",
                "info",
            ]
        )
        if "cluster_state:ok" in info.stdout:
            return
        time.sleep(0.5)
    raise TimeoutError("isolated cluster never reached cluster_state:ok")
def _wait_cluster_ok(timeout_s: float = 30.0) -> bool:
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        info = _run(
            [
                "docker",
                "exec",
                f"{ISOLATED_PROJECT}-redis-0",
                "redis-cli",
                "-p",
                str(ISOLATED_PORTS[0]),
                "cluster",
                "info",
            ]
        )
        if "cluster_state:ok" in info.stdout:
            return True
        time.sleep(0.5)
    return False


def _teardown_isolated_cluster() -> None:
    for i in range(3):
        _run(["docker", "rm", "-f", f"{ISOLATED_PROJECT}-redis-{i}"])
    _run(["docker", "network", "rm", f"{ISOLATED_PROJECT}-net"])


@pytest.fixture(scope="module")
def isolated_cluster():
    """Module-scoped: tests share one cluster lifecycle."""
    _start_isolated_cluster()
    try:
        yield
    finally:
        _teardown_isolated_cluster()
@pytest.mark.asyncio
@pytest.mark.slow
@cluster_restart_only
async def test_subscriber_survives_shard_restart(isolated_cluster, monkeypatch) -> None:
    """Subscriber must receive a post-`docker restart` SPUBLISH after
    reopening the sharded-pubsub client (the broker drops the socket on
    restart; production's `with_pubsub` loop reconnects the same way)."""
    # Must override REDIS_CLUSTER_HOST/PORT too — those take precedence
    # over REDIS_HOST/PORT and a stray .env would point us at the dev cluster.
    monkeypatch.setenv("REDIS_HOST", "127.0.0.1")
    monkeypatch.setenv("REDIS_PORT", str(ISOLATED_PORTS[0]))
    monkeypatch.setenv("REDIS_CLUSTER_HOST", "127.0.0.1")
    monkeypatch.setenv("REDIS_CLUSTER_PORT", str(ISOLATED_PORTS[0]))
    monkeypatch.setenv("REDIS_USE_ANNOUNCED_ADDRESS", "false")
    monkeypatch.delenv("REDIS_PASSWORD", raising=False)

    import backend.data.redis_client as rc

    importlib.reload(rc)

    # Restart whichever container owns the keyslot, not a guess.
    cluster = rc.get_redis()
    target_tag = f"restart-{uuid4().hex[:8]}"
    channel = "{" + target_tag + "}/restart-test"
    owner = cluster.get_node_from_key(channel)
    port_to_idx = {p: i for i, p in enumerate(ISOLATED_PORTS)}
    target_idx = port_to_idx.get(owner.port)
    assert (
        target_idx is not None
    ), f"owner port {owner.port} not in known set {ISOLATED_PORTS}"
    target_container = f"{ISOLATED_PROJECT}-redis-{target_idx}"

    client = await rc.connect_sharded_pubsub_async(channel)
    pubsub = client.pubsub()
    await pubsub.execute_command("SSUBSCRIBE", channel)
    pubsub.channels[channel] = None  # type: ignore[index]

    received: list[str] = []

    async def _drain_one() -> str | None:
        try:
            async for msg in pubsub.listen():
                if msg.get("type") == "smessage":
                    return msg["data"]
        except Exception:
            return None
        return None

    try:
        async_cluster = await rc.get_redis_async()
        await async_cluster.execute_command("SPUBLISH", channel, "before-restart")

        first = await asyncio.wait_for(_drain_one(), timeout=6.0)
        received.append(first or "")
        assert received == [
            "before-restart"
        ], f"pre-restart publish did not arrive: {received}"

        # Restart the shard that owns the slot.
        rc_restart = _run(["docker", "restart", "--time", "1", target_container])
        assert rc_restart.returncode == 0, rc_restart.stderr

        assert _wait_cluster_ok(
            timeout_s=30
        ), "isolated cluster never re-converged to state=ok after restart"
        # Hold a small grace window for the shard's gossip to settle.
        await asyncio.sleep(1.0)

        # Old socket is dead — open a fresh sharded-pubsub connection.
        try:
            await pubsub.aclose()
        except Exception:
            pass
        try:
            await client.aclose()
        except Exception:
            pass
        rc._async_clients.clear()

        client2 = await rc.connect_sharded_pubsub_async(channel)
        pubsub2 = client2.pubsub()
        try:
            await pubsub2.execute_command("SSUBSCRIBE", channel)
            pubsub2.channels[channel] = None  # type: ignore[index]

            # Drain the SSUBSCRIBE confirm.
            async for _msg in pubsub2.listen():
                break

            async def _drain_after() -> str | None:
                async for msg in pubsub2.listen():
                    if msg.get("type") == "smessage":
                        return msg["data"]
                return None

            async_cluster_2 = await rc.get_redis_async()
            await async_cluster_2.execute_command("SPUBLISH", channel, "after-restart")

            data = await asyncio.wait_for(_drain_after(), timeout=15.0)
            assert (
                data == "after-restart"
            ), f"subscriber did not receive post-restart event (got {data!r})"
        finally:
            try:
                await pubsub2.execute_command("SUNSUBSCRIBE", channel)
            except Exception:
                pass
            try:
                await pubsub2.aclose()
            except Exception:
                pass
            await client2.aclose()
    finally:
        try:
            await pubsub.aclose()
        except Exception:
            pass
        try:
            await client.aclose()
        except Exception:
            pass
        await rc.disconnect_async()
        # Undo monkeypatched env BEFORE reloading so subsequent tests see the
        # original REDIS_HOST/PORT — otherwise the module captures the
        # isolated cluster's port (27110) which is torn down right after this
        # test, and any later test that touches redis hangs on conn_retry.
        monkeypatch.undo()
        importlib.reload(rc)
@@ -1,7 +1,15 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Generator,
|
||||
Generic,
|
||||
Optional,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
@@ -11,6 +19,9 @@ from backend.data import redis_client as redis
|
||||
from backend.util import json
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Settings().config
|
||||
|
||||
@@ -18,6 +29,15 @@ config = Settings().config
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
|
||||
|
||||
def _assert_no_wildcard(channel_key: str) -> None:
|
||||
"""Sharded pub/sub has no pattern-subscribe; fail fast on wildcards."""
|
||||
if "*" in channel_key:
|
||||
raise ValueError(
|
||||
f"channel_key {channel_key!r} contains a wildcard; sharded pub/sub "
|
||||
"(SSUBSCRIBE) requires exact channel names."
|
||||
)
|
||||
|
||||
|
||||
class BaseRedisEventBus(Generic[M], ABC):
|
||||
Model: type[M]
|
||||
|
||||
@@ -71,8 +91,8 @@ class BaseRedisEventBus(Generic[M], ABC):
        return message, channel_name

    def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
        message_type = "pmessage" if "*" in channel_key else "message"
        if msg["type"] != message_type:
        # Accept sharded (smessage) and classic (message/pmessage) deliveries.
        if msg["type"] not in ("smessage", "message", "pmessage"):
            return None
        try:
            logger.debug(f"[{channel_key}] Consuming an event from Redis {msg['data']}")
@@ -80,12 +100,8 @@ class BaseRedisEventBus(Generic[M], ABC):
        except Exception as e:
            logger.error(f"Failed to parse event result from Redis {msg} {e}")

    def _get_pubsub_channel(
        self, connection: redis.Redis | redis.AsyncRedis, channel_key: str
    ) -> tuple[PubSub | AsyncPubSub, str]:
        full_channel_name = f"{self.event_bus_name}/{channel_key}"
        pubsub = connection.pubsub()
        return pubsub, full_channel_name
    def _build_channel_name(self, channel_key: str) -> str:
        return f"{self.event_bus_name}/{channel_key}"


class _EventPayloadWrapper(BaseModel, Generic[M]):
@@ -98,88 +114,97 @@ class _EventPayloadWrapper(BaseModel, Generic[M]):


class RedisEventBus(BaseRedisEventBus[M], ABC):
    @property
    def connection(self) -> redis.Redis:
        return redis.get_redis()

    def publish_event(self, event: M, channel_key: str):
        """
        Publish an event to Redis. Gracefully handles connection failures
        by logging the error instead of raising exceptions.
        """
        """Publish via SPUBLISH; swallow failures so Redis blips don't crash callers."""
        _assert_no_wildcard(channel_key)
        try:
            message, full_channel_name = self._serialize_message(event, channel_key)
            self.connection.publish(full_channel_name, message)
            cluster = redis.get_redis()
            cluster.execute_command("SPUBLISH", full_channel_name, message)
        except Exception:
            logger.exception(
                f"Failed to publish event to Redis channel {channel_key}. "
                "Event bus operation will continue without Redis connectivity."
            )
            logger.exception(f"Failed to publish event to Redis channel {channel_key}")

    def listen_events(self, channel_key: str) -> Generator[M, None, None]:
        pubsub, full_channel_name = self._get_pubsub_channel(
            self.connection, channel_key
        )
        assert isinstance(pubsub, PubSub)
        _assert_no_wildcard(channel_key)
        full_channel_name = self._build_channel_name(channel_key)

        if "*" in channel_key:
            pubsub.psubscribe(full_channel_name)
        else:
            pubsub.subscribe(full_channel_name)

        for message in pubsub.listen():
            if event := self._deserialize_message(message, channel_key):
                yield event
        cluster = redis.get_redis()
        pubsub: PubSub = cluster.pubsub()
        try:
            pubsub.ssubscribe(full_channel_name)
            for message in pubsub.listen():
                if event := self._deserialize_message(message, channel_key):
                    yield event
        finally:
            try:
                pubsub.sunsubscribe(full_channel_name)
            except Exception:
                logger.warning(
                    "Failed to SUNSUBSCRIBE from %s", full_channel_name, exc_info=True
                )
            try:
                pubsub.close()
            except Exception:
                logger.warning(
                    "Failed to close sharded pubsub for %s",
                    full_channel_name,
                    exc_info=True,
                )


class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
    def __init__(self):
        self._pubsub: AsyncPubSub | None = None

    @property
    async def connection(self) -> redis.AsyncRedis:
        return await redis.get_redis_async()

    async def close(self) -> None:
        """Close the PubSub connection if it exists."""
        if self._pubsub is not None:
            try:
                await self._pubsub.close()
            except Exception:
                logger.warning("Failed to close PubSub connection", exc_info=True)
            finally:
                self._pubsub = None
        """No-op kept for backward compatibility.

        Earlier revisions of this class stored the per-listen pubsub on the
        instance, requiring an external close. ``listen_events`` now owns its
        own client/pubsub locally so concurrent calls on a singleton (e.g.
        ``_webhook_event_bus``) cannot clobber each other's connection.
        """
        return None

    async def publish_event(self, event: M, channel_key: str):
        """
        Publish an event to Redis. Gracefully handles connection failures
        by logging the error instead of raising exceptions.
        """
        """Publish via SPUBLISH; swallow failures so Redis blips don't crash callers."""
        _assert_no_wildcard(channel_key)
        try:
            message, full_channel_name = self._serialize_message(event, channel_key)
            connection = await self.connection
            await connection.publish(full_channel_name, message)
            cluster = await redis.get_redis_async()
            # redis-py 6.x async cluster has no spublish(); execute_command handles MOVED.
            await cluster.execute_command("SPUBLISH", full_channel_name, message)
        except Exception:
            logger.exception(
                f"Failed to publish event to Redis channel {channel_key}. "
                "Event bus operation will continue without Redis connectivity."
            )
            logger.exception(f"Failed to publish event to Redis channel {channel_key}")

    async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
        pubsub, full_channel_name = self._get_pubsub_channel(
            await self.connection, channel_key
        _assert_no_wildcard(channel_key)
        full_channel_name = self._build_channel_name(channel_key)

        # Sharded pub/sub only delivers on the keyslot-owning shard, so pin
        # a plain AsyncRedis to that node. Both client and pubsub stay
        # generator-local — concurrent listen_events on the same instance
        # (e.g. the singleton _webhook_event_bus) must not share state.
        client: "AsyncRedis" = await redis.connect_sharded_pubsub_async(
            full_channel_name
        )
        assert isinstance(pubsub, AsyncPubSub)
        self._pubsub = pubsub

        if "*" in channel_key:
            await pubsub.psubscribe(full_channel_name)
        else:
            await pubsub.subscribe(full_channel_name)

        async for message in pubsub.listen():
            if event := self._deserialize_message(message, channel_key):
                yield event
        pubsub: AsyncPubSub = client.pubsub()
        try:
            await pubsub.execute_command("SSUBSCRIBE", full_channel_name)
            # redis-py 6.x async PubSub.listen() exits when ``channels`` is
            # empty; raw SSUBSCRIBE doesn't populate it, so do it ourselves.
            pubsub.channels[full_channel_name] = None  # type: ignore[index]
            async for message in pubsub.listen():
                if event := self._deserialize_message(message, channel_key):
                    yield event
        finally:
            try:
                await pubsub.aclose()
            except Exception:
                logger.warning("Failed to close PubSub connection", exc_info=True)
            try:
                await client.aclose()
            except Exception:
                logger.warning(
                    "Failed to close shard-pinned Redis connection", exc_info=True
                )

    async def wait_for_event(
        self, channel_key: str, timeout: Optional[float] = None

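The generator-local ownership pattern used above (resource acquired inside the generator, released in `finally` even when the consumer exits early) can be demonstrated without Redis. `Resource`, `stream`, and `main` are hypothetical names for the sketch; `Resource` stands in for the shard-pinned client:

```python
import asyncio

closed: list["Resource"] = []


class Resource:
    """Hypothetical stand-in for a per-listen, shard-pinned connection."""

    async def aclose(self) -> None:
        closed.append(self)


async def stream():
    # Each call owns its own Resource; the finally block releases it even
    # if the consumer stops iterating early.
    res = Resource()
    try:
        for i in range(3):
            yield i
    finally:
        await res.aclose()


async def main() -> int:
    agen = stream()
    async for i in agen:
        if i == 1:
            break  # consumer bails early
    await agen.aclose()  # closing the generator still runs its finally
    return len(closed)


assert asyncio.run(main()) == 1
```

Because the resource never lives on `self`, two concurrent `listen_events` calls on one singleton each tear down their own connection, which is exactly what the concurrency test later in this diff pins down.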
@@ -1,25 +1,26 @@
"""
Tests for event_bus graceful degradation when Redis is unavailable.
"""
"""Tests for event_bus publish/listen paths."""

from unittest.mock import AsyncMock, patch
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel

from backend.data.event_bus import AsyncRedisEventBus
from backend.data.event_bus import (
    AsyncRedisEventBus,
    RedisEventBus,
    _assert_no_wildcard,
)


class TestEvent(BaseModel):
    """Test event model."""
class SampleEvent(BaseModel):
    """Minimal event model used by the tests below."""

    message: str


class TestNotificationBus(AsyncRedisEventBus[TestEvent]):
    """Test implementation of AsyncRedisEventBus."""

    Model = TestEvent
class _BusUnderTest(AsyncRedisEventBus[SampleEvent]):
    Model = SampleEvent

    @property
    def event_bus_name(self) -> str:
@@ -28,11 +29,10 @@ class TestNotificationBus(AsyncRedisEventBus[TestEvent]):

@pytest.mark.asyncio
async def test_publish_event_handles_connection_failure_gracefully():
    """Test that publish_event logs exception instead of raising when Redis is unavailable."""
    bus = TestNotificationBus()
    event = TestEvent(message="test message")
    """publish_event must log and swallow when the cluster client is down."""
    bus = _BusUnderTest()
    event = SampleEvent(message="test message")

    # Mock get_redis_async to raise connection error
    with patch(
        "backend.data.event_bus.redis.get_redis_async",
        side_effect=ConnectionError("Authentication required."),
@@ -42,15 +42,487 @@ async def test_publish_event_handles_connection_failure_gracefully():


@pytest.mark.asyncio
async def test_publish_event_works_with_redis_available():
    """Test that publish_event works normally when Redis is available."""
    bus = TestNotificationBus()
    event = TestEvent(message="test message")
async def test_publish_event_spublishes_via_cluster_client():
    """publish_event routes a single SPUBLISH through the cluster client."""
    bus = _BusUnderTest()
    event = SampleEvent(message="test message")

    # Mock successful Redis connection
    mock_redis = AsyncMock()
    mock_redis.publish = AsyncMock()
    mock_cluster = MagicMock()
    mock_cluster.execute_command = AsyncMock()

    with patch("backend.data.event_bus.redis.get_redis_async", return_value=mock_redis):
    with patch(
        "backend.data.event_bus.redis.get_redis_async", return_value=mock_cluster
    ):
        await bus.publish_event(event, "test_channel")
        mock_redis.publish.assert_called_once()

    mock_cluster.execute_command.assert_awaited_once()
    assert mock_cluster.execute_command.await_args[0][0] == "SPUBLISH"


@pytest.mark.asyncio
async def test_publish_event_rejects_wildcard_channel():
    """A channel_key containing ``*`` must raise — no silent no-op."""
    bus = _BusUnderTest()
    with patch("backend.data.event_bus.redis.get_redis_async") as get_cluster:
        with pytest.raises(ValueError):
            await bus.publish_event(SampleEvent(message="m"), "user/*/exec")
        # The cluster client must never be reached for a wildcard channel.
        get_cluster.assert_not_called()


def test_assert_no_wildcard_guard():
    """The standalone guard must reject any ``*``-containing channel."""
    with pytest.raises(ValueError):
        _assert_no_wildcard("user/*/exec")
    # Concrete channels must pass.
    _assert_no_wildcard("execution_event/user-1/graph-1/exec-1")


# Live SSUBSCRIBE round-trip; skipped when no cluster is reachable.


def _has_live_cluster() -> bool:
    from backend.data import redis_client

    try:
        c = redis_client.connect()
    except Exception:  # noqa: BLE001 - any connect failure → skip the test
        return False
    try:
        c.close()
    except Exception:
        pass
    return True


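`_has_live_cluster` above pays for a full client connect just to decide a skip. A lighter-weight alternative (an assumption for illustration, not what this suite uses) is a raw TCP probe that fits the same `skipif` condition:

```python
import socket


def port_open(host: str, port: int, timeout: float = 0.2) -> bool:
    """Cheap reachability probe usable as a pytest skipif condition."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


# Usage sketch (6379 is the conventional Redis port, assumed here):
# @pytest.mark.skipif(not port_open("127.0.0.1", 6379), reason="no redis")
```

The trade-off: a TCP probe cannot detect auth or cluster-mode misconfiguration, which is why the suite's connect-based check is the safer gate.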
@pytest.mark.asyncio
@pytest.mark.skipif(
    not _has_live_cluster(),
    reason="local redis cluster not reachable; skip SSUBSCRIBE integration",
)
async def test_ssubscribe_end_to_end_async():
    """SPUBLISH on one AsyncRedisEventBus reaches SSUBSCRIBE on another."""
    import asyncio

    from backend.data import redis_client

    redis_client.get_redis.cache_clear()
    redis_client._async_clients.clear()

    publisher = _BusUnderTest()
    subscriber = _BusUnderTest()
    channel_key = "pr12900:event_bus:integration"

    received: list[SampleEvent] = []

    async def consume() -> None:
        async for ev in subscriber.listen_events(channel_key):
            received.append(ev)
            return

    task = asyncio.create_task(consume())
    # Let SSUBSCRIBE settle; races drop the publish otherwise.
    await asyncio.sleep(0.3)
    try:
        await publisher.publish_event(SampleEvent(message="hello-ssub"), channel_key)
        await asyncio.wait_for(task, timeout=5.0)
    finally:
        if not task.done():
            task.cancel()
        await subscriber.close()
        await redis_client.disconnect_async()

    assert received and received[0].message == "hello-ssub"


@pytest.mark.asyncio
@pytest.mark.skipif(
    not _has_live_cluster(),
    reason="local redis cluster not reachable; skip execution-bus integration",
)
async def test_execution_bus_listen_and_listen_graph_both_deliver():
    """Per-exec and per-graph channels both receive every execution event."""
    import asyncio
    from datetime import datetime, timezone

    from backend.data import redis_client
    from backend.data.execution import (
        AsyncRedisExecutionEventBus,
        ExecutionStatus,
        GraphExecutionEvent,
    )

    redis_client.get_redis.cache_clear()
    redis_client._async_clients.clear()

    user_id = "user-it"
    graph_id = "graph-it"
    exec_id = "exec-it"
    now = datetime.now(tz=timezone.utc)

    event = GraphExecutionEvent(
        id=exec_id,
        user_id=user_id,
        graph_id=graph_id,
        graph_version=1,
        preset_id=None,
        status=ExecutionStatus.COMPLETED,
        started_at=now,
        ended_at=now,
        stats=GraphExecutionEvent.Stats(
            cost=0, duration=0.1, node_exec_time=0.1, node_exec_count=1
        ),
        inputs={},
        credential_inputs=None,
        nodes_input_masks=None,
        outputs={},
    )

    single = AsyncRedisExecutionEventBus()
    all_execs = AsyncRedisExecutionEventBus()
    publisher = AsyncRedisExecutionEventBus()

    received_single: list = []
    received_all: list = []

    async def _listen_single() -> None:
        async for ev in single.listen(user_id, graph_id, exec_id):
            received_single.append(ev)
            return

    async def _listen_all() -> None:
        async for ev in all_execs.listen_graph(user_id, graph_id):
            received_all.append(ev)
            return

    t1 = asyncio.create_task(_listen_single())
    t2 = asyncio.create_task(_listen_all())
    await asyncio.sleep(0.3)

    try:
        await publisher.publish(event)
        await asyncio.wait_for(asyncio.gather(t1, t2), timeout=5.0)
    finally:
        for t in (t1, t2):
            if not t.done():
                t.cancel()
        await single.close()
        await all_execs.close()
        await publisher.close()
        await redis_client.disconnect_async()

    assert received_single and received_single[0].id == exec_id
    assert received_all and received_all[0].id == exec_id


@pytest.mark.asyncio
async def test_listen_events_rejects_wildcard_channel():
    """listen_events on a wildcard channel must raise before touching Redis."""
    bus = _BusUnderTest()
    with pytest.raises(ValueError):
        async for _ in bus.listen_events("user/*/exec"):
            break


# ---------- Serialization + size guard ----------


def test_serialize_message_tags_full_channel_name():
    """_serialize_message returns the ``<bus>/<key>`` full channel name."""
    bus = _BusUnderTest()
    _, full = bus._serialize_message(SampleEvent(message="x"), "chan")
    assert full == "test_event_bus/chan"


def test_serialize_message_truncates_oversized_payload(monkeypatch):
    """If the payload exceeds max_message_size_limit, it's replaced with an
    ``error_comms_update`` payload rather than crashing the cluster."""
    import backend.data.event_bus as event_bus

    bus = _BusUnderTest()
    # Cap tiny to force truncation.
    monkeypatch.setattr(event_bus.config, "max_message_size_limit", 50)
    message, _ = bus._serialize_message(SampleEvent(message="x" * 1000), "chan")
    assert "error_comms_update" in message
    assert "Payload too large" in message


def test_deserialize_message_rejects_non_pubsub_types():
    """Non ``smessage|message|pmessage`` deliveries deserialize to None."""
    bus = _BusUnderTest()
    assert bus._deserialize_message({"type": "ssubscribe", "data": 1}, "c") is None
    assert bus._deserialize_message({"type": "subscribe", "data": 1}, "c") is None


def test_deserialize_message_swallows_bad_json():
    """Corrupted payloads must not raise — they return None (logged elsewhere)."""
    bus = _BusUnderTest()
    assert (
        bus._deserialize_message({"type": "smessage", "data": "not-json"}, "c") is None
    )


def test_deserialize_message_parses_smessage():
    """Happy-path ``smessage`` yields the inner event model."""
    bus = _BusUnderTest()
    wrapped = '{"payload":{"message":"hi"}}'
    parsed = bus._deserialize_message({"type": "smessage", "data": wrapped}, "chan")
    assert parsed is not None and parsed.message == "hi"


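The truncation behavior those tests pin down can be sketched standalone. The cap constant and function name here are illustrative assumptions, not the module's API; only the `error_comms_update` / "Payload too large" markers come from the tests above:

```python
import json

MAX_MESSAGE_SIZE = 50  # assumed tiny cap, analogous to max_message_size_limit


def serialize_capped(payload: dict) -> str:
    """Wrap the payload; if the wire form exceeds the cap, ship a small
    error marker instead of an oversized pub/sub message."""
    raw = json.dumps({"payload": payload})
    if len(raw) > MAX_MESSAGE_SIZE:
        return json.dumps(
            {"payload": {"type": "error_comms_update", "error": "Payload too large"}}
        )
    return raw


assert "error_comms_update" in serialize_capped({"message": "x" * 1000})
assert "hi" in serialize_capped({"message": "hi"})
```

Replacing, rather than raising on, an oversized payload keeps the subscriber stream alive: the watcher still learns that an update happened, just not its full body.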
# ---------- Sync RedisEventBus ----------


class _SyncBusUnderTest(RedisEventBus[SampleEvent]):
    Model = SampleEvent

    @property
    def event_bus_name(self) -> str:
        return "test_event_bus"


def test_sync_publish_event_spublish_only():
    """Sync publish_event must issue a single SPUBLISH (no classic fallback)."""
    bus = _SyncBusUnderTest()
    cluster = MagicMock()
    cluster.execute_command = MagicMock()

    with patch("backend.data.event_bus.redis.get_redis", return_value=cluster):
        bus.publish_event(SampleEvent(message="m"), "chan")

    cluster.execute_command.assert_called_once()
    assert cluster.execute_command.call_args.args[0] == "SPUBLISH"


def test_sync_publish_event_rejects_wildcard():
    bus = _SyncBusUnderTest()
    with patch("backend.data.event_bus.redis.get_redis") as mock_get:
        with pytest.raises(ValueError):
            bus.publish_event(SampleEvent(message="m"), "user/*/exec")
        mock_get.assert_not_called()


def test_sync_publish_event_swallows_connection_errors():
    """publish_event must never raise to callers — logs + drops on failure."""
    bus = _SyncBusUnderTest()
    with patch(
        "backend.data.event_bus.redis.get_redis",
        side_effect=ConnectionError("no redis"),
    ):
        # Should NOT raise.
        bus.publish_event(SampleEvent(message="m"), "chan")


def test_sync_listen_events_rejects_wildcard():
    bus = _SyncBusUnderTest()
    with pytest.raises(ValueError):
        next(iter(bus.listen_events("user/*/exec")))


def test_sync_listen_events_ssubscribes_and_yields_decoded_events():
    """Sync listen_events: SSUBSCRIBE on the full channel, decode smessage payloads."""
    bus = _SyncBusUnderTest()

    fake_pubsub = MagicMock()
    fake_pubsub.ssubscribe = MagicMock()
    fake_pubsub.sunsubscribe = MagicMock()
    fake_pubsub.close = MagicMock()
    fake_pubsub.listen = MagicMock(
        return_value=iter(
            [
                {"type": "ssubscribe", "data": 1},
                {"type": "smessage", "data": '{"payload":{"message":"one"}}'},
            ]
        )
    )

    cluster = MagicMock()
    cluster.pubsub = MagicMock(return_value=fake_pubsub)

    with patch("backend.data.event_bus.redis.get_redis", return_value=cluster):
        gen = bus.listen_events("chan")
        first = next(iter(gen))

    assert first.message == "one"
    fake_pubsub.ssubscribe.assert_called_once_with("test_event_bus/chan")


def test_sync_listen_events_teardown_swallows_sunsubscribe_errors():
    """Teardown must not propagate SUNSUBSCRIBE/close failures."""
    bus = _SyncBusUnderTest()

    fake_pubsub = MagicMock()
    fake_pubsub.ssubscribe = MagicMock()
    fake_pubsub.sunsubscribe = MagicMock(side_effect=RuntimeError("SUNSUB broke"))
    fake_pubsub.close = MagicMock(side_effect=RuntimeError("close broke"))
    fake_pubsub.listen = MagicMock(return_value=iter([]))
    cluster = MagicMock()
    cluster.pubsub = MagicMock(return_value=fake_pubsub)

    with patch("backend.data.event_bus.redis.get_redis", return_value=cluster):
        # Exhausting the generator runs the ``finally`` teardown.
        list(bus.listen_events("chan"))

    fake_pubsub.sunsubscribe.assert_called_once()
    fake_pubsub.close.assert_called_once()


# ---------- Async close() teardown ----------


@pytest.mark.asyncio
async def test_async_close_is_noop():
    """close() is a backward-compat no-op now that listen_events owns its own state."""
    bus = _BusUnderTest()
    # Repeated calls must not crash; pubsub/client are generator-locals.
    await bus.close()
    await bus.close()


@pytest.mark.asyncio
async def test_async_listen_events_swallows_aclose_errors():
    """Broken pubsub.aclose / client.aclose must not propagate to the caller."""
    bus = _BusUnderTest()

    fake_pubsub = MagicMock()
    fake_pubsub.execute_command = AsyncMock()
    fake_pubsub.channels = {}
    fake_pubsub.aclose = AsyncMock(side_effect=RuntimeError("pubsub broke"))

    async def _listen():
        return
        yield  # pragma: no cover — unreachable

    fake_pubsub.listen = _listen

    fake_client = MagicMock()
    fake_client.pubsub = MagicMock(return_value=fake_pubsub)
    fake_client.aclose = AsyncMock(side_effect=RuntimeError("client broke"))

    with patch(
        "backend.data.event_bus.redis.connect_sharded_pubsub_async",
        AsyncMock(return_value=fake_client),
    ):
        async for _ in bus.listen_events("chan"):
            pass  # pragma: no cover — never yields

    # Both aclose attempts must have run despite raising.
    fake_pubsub.aclose.assert_awaited_once()
    fake_client.aclose.assert_awaited_once()


@pytest.mark.asyncio
async def test_async_listen_events_concurrent_does_not_share_state():
    """Two concurrent listens on the same bus must keep their pubsub/client local."""
    bus = _BusUnderTest()

    pubsubs: list[MagicMock] = []
    clients: list[MagicMock] = []
    started = asyncio.Event()
    proceed = asyncio.Event()

    def _make_pair() -> tuple[MagicMock, MagicMock]:
        pubsub = MagicMock()
        pubsub.execute_command = AsyncMock()
        pubsub.channels = {}
        pubsub.aclose = AsyncMock()

        async def _listen():
            started.set()
            await proceed.wait()
            return
            yield  # pragma: no cover — unreachable

        pubsub.listen = _listen

        client = MagicMock()
        client.pubsub = MagicMock(return_value=pubsub)
        client.aclose = AsyncMock()
        pubsubs.append(pubsub)
        clients.append(client)
        return pubsub, client

    async def _factory(_chan: str):
        _, client = _make_pair()
        return client

    with patch(
        "backend.data.event_bus.redis.connect_sharded_pubsub_async",
        AsyncMock(side_effect=_factory),
    ):

        async def _run():
            async for _ in bus.listen_events("chan"):
                pass  # pragma: no cover — never yields

        task_a = asyncio.create_task(_run())
        task_b = asyncio.create_task(_run())
        # Wait for both pumps to be parked inside listen() before unblocking.
        await started.wait()
        # Yield once more so the second task also enters listen().
        await asyncio.sleep(0)
        proceed.set()
        await asyncio.gather(task_a, task_b)

    # Each listen must have closed its OWN pubsub/client exactly once. If
    # either was closed twice or zero times, the singleton race is back.
    assert len(pubsubs) == 2
    for pubsub in pubsubs:
        pubsub.aclose.assert_awaited_once()
    for client in clients:
        client.aclose.assert_awaited_once()


@pytest.mark.asyncio
async def test_async_wait_for_event_returns_none_on_timeout():
    """wait_for_event must coerce asyncio.TimeoutError → None."""
    import asyncio as _asyncio

    bus = _BusUnderTest()

    async def _never(self, channel_key):
        await _asyncio.sleep(10)
        yield  # pragma: no cover — unreachable

    with patch.object(_BusUnderTest, "listen_events", _never):
        result = await bus.wait_for_event("chan", timeout=0.01)

    assert result is None


# The listen_events async happy path is covered by the live-cluster integration
# test above; this one exercises the close-on-exception fallback.
@pytest.mark.asyncio
async def test_async_listen_events_closes_on_exception():
    """If the pump raises, close() must still run to release the shard-pinned client."""
    bus = _BusUnderTest()

    fake_pubsub = MagicMock()
    fake_pubsub.execute_command = AsyncMock()
    fake_pubsub.channels = {}
    fake_pubsub.aclose = AsyncMock()

    class _Boom(Exception):
        pass

    async def _listen():
        raise _Boom()
        yield  # pragma: no cover — unreachable

    fake_pubsub.listen = _listen

    fake_client = MagicMock()
    fake_client.pubsub = MagicMock(return_value=fake_pubsub)
    fake_client.aclose = AsyncMock()

    with patch(
        "backend.data.event_bus.redis.connect_sharded_pubsub_async",
        AsyncMock(return_value=fake_client),
    ):
        with pytest.raises(_Boom):
            async for _ in bus.listen_events("chan"):
                pass

    # close() must have fired (both aclose calls).
    fake_pubsub.aclose.assert_awaited_once()
    fake_client.aclose.assert_awaited_once()

@@ -570,7 +570,7 @@ async def get_graph_executions(
    # Build properly typed order clause
    # Prisma wants specific typed dicts for each field, so we construct them explicitly
    order_clause: AgentGraphExecutionOrderByInput
    match (order_by):
    match order_by:
        case "startedAt":
            order_clause = {
                "startedAt": order_direction,
@@ -1337,6 +1337,22 @@ ExecutionEvent = Annotated[
]


# Hash-tagged channels keep per-exec and per-graph keys on the same shard,
# so one SSUBSCRIBE connection can watch both.


def _graph_scope_tag(user_id: str, graph_id: str) -> str:
    return "{" + f"{user_id}/{graph_id}" + "}"


def exec_channel(user_id: str, graph_id: str, graph_exec_id: str) -> str:
    return f"{_graph_scope_tag(user_id, graph_id)}/exec/{graph_exec_id}"


def graph_all_channel(user_id: str, graph_id: str) -> str:
    return f"{_graph_scope_tag(user_id, graph_id)}/all"


class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
    Model = ExecutionEvent  # type: ignore

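The hash-tag builders above rely on Redis Cluster's keyslot rule: only the first non-empty `{...}` segment of a key is hashed (CRC16/XMODEM mod 16384), so `.../exec/<id>` and `.../all` channels sharing a tag land on the same shard. A self-contained sketch of that rule (our own reimplementation for illustration, not redis-py's):

```python
def crc16_xmodem(data: bytes) -> int:
    """CRC16/XMODEM (poly 0x1021, init 0), the checksum Redis Cluster slots use."""
    crc = 0
    for byte in data:
        crc ^= byte << 8
        for _ in range(8):
            crc = ((crc << 1) ^ 0x1021) if crc & 0x8000 else (crc << 1)
            crc &= 0xFFFF
    return crc


def keyslot(key: str) -> int:
    """Hash only the first non-empty {...} segment, per the cluster spec."""
    start = key.find("{")
    if start != -1:
        end = key.find("}", start + 1)
        if end > start + 1:  # tag must be non-empty
            key = key[start + 1 : end]
    return crc16_xmodem(key.encode()) % 16384


tag = "{user-1/graph-1}"
# Same tag → same slot → one shard-pinned SSUBSCRIBE can carry both channels.
assert keyslot(f"{tag}/exec/e-1") == keyslot(f"{tag}/all")
```

This is exactly why `_graph_scope_tag` wraps `user_id/graph_id` in braces: without the tag, the per-exec and per-graph channels would hash to unrelated slots and need separate shard-pinned connections.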
@@ -1352,16 +1368,20 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):

    def _publish_node_exec_update(self, res: NodeExecutionResult):
        event = NodeExecutionEvent.model_validate(res.model_dump())
        self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
        self._publish(event, res.user_id, res.graph_id, res.graph_exec_id)

    def _publish_graph_exec_update(self, res: GraphExecution):
        event = GraphExecutionEvent.model_validate(res.model_dump())
        self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
        self._publish(event, res.user_id, res.graph_id, res.id)

    def _publish(self, event: ExecutionEvent, channel: str):
        """
        truncate inputs and outputs to avoid large payloads
        """
    def _publish(
        self,
        event: ExecutionEvent,
        user_id: str,
        graph_id: str,
        graph_exec_id: str,
    ):
        """Truncate oversized payloads, then publish to per-exec + per-graph channels."""
        limit = config.max_message_size_limit // 2
        if isinstance(event, GraphExecutionEvent):
            event.inputs = truncate(event.inputs, limit)
@@ -1370,12 +1390,22 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
            event.input_data = truncate(event.input_data, limit)
            event.output_data = truncate(event.output_data, limit)

        super().publish_event(event, channel)
        # Publisher fans out: per-exec and per-graph watchers.
        super().publish_event(event, exec_channel(user_id, graph_id, graph_exec_id))
        super().publish_event(event, graph_all_channel(user_id, graph_id))

    def listen(
        self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
        self, user_id: str, graph_id: str, graph_exec_id: str
    ) -> Generator[ExecutionEvent, None, None]:
        for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
        """Stream events for a specific graph execution."""
        for event in self.listen_events(exec_channel(user_id, graph_id, graph_exec_id)):
            yield event

    def listen_graph(
        self, user_id: str, graph_id: str
    ) -> Generator[ExecutionEvent, None, None]:
        """Stream every event for every execution of ``graph_id``."""
        for event in self.listen_events(graph_all_channel(user_id, graph_id)):
            yield event


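The fan-out in `_publish` above reduces to a few lines. `publish` here is a hypothetical sink standing in for `publish_event`; the channel shapes come from the `exec_channel` / `graph_all_channel` builders in this diff:

```python
published: list[tuple[str, str]] = []


def publish(channel: str, message: str) -> None:
    # Hypothetical sink standing in for RedisEventBus.publish_event.
    published.append((channel, message))


def publish_exec_event(user_id: str, graph_id: str, exec_id: str, message: str) -> None:
    # One logical event lands on both the per-execution channel and the
    # per-graph firehose, so either kind of watcher sees it.
    tag = "{" + f"{user_id}/{graph_id}" + "}"
    publish(f"{tag}/exec/{exec_id}", message)
    publish(f"{tag}/all", message)


publish_exec_event("u1", "g1", "e1", "done")
assert [c for c, _ in published] == ["{u1/g1}/exec/e1", "{u1/g1}/all"]
```

Fanning out at publish time replaces the old wildcard `psubscribe` pattern, which sharded pub/sub cannot express; the cost is one extra SPUBLISH per event.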
@@ -1395,7 +1425,7 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):

    async def _publish_node_exec_update(self, res: NodeExecutionResult):
        event = NodeExecutionEvent.model_validate(res.model_dump())
        await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
        await self._publish(event, res.user_id, res.graph_id, res.graph_exec_id)

    async def _publish_graph_exec_update(self, res: GraphExecutionMeta):
        # GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
@@ -1404,12 +1434,16 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
        event_data.setdefault("inputs", {})
        event_data.setdefault("outputs", {})
        event = GraphExecutionEvent.model_validate(event_data)
        await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
        await self._publish(event, res.user_id, res.graph_id, res.id)

    async def _publish(self, event: ExecutionEvent, channel: str):
        """
        truncate inputs and outputs to avoid large payloads
        """
    async def _publish(
        self,
        event: ExecutionEvent,
        user_id: str,
        graph_id: str,
        graph_exec_id: str,
    ):
        """Truncate oversized payloads, then publish to per-exec + per-graph channels."""
        limit = config.max_message_size_limit // 2
        if isinstance(event, GraphExecutionEvent):
            event.inputs = truncate(event.inputs, limit)
@@ -1418,12 +1452,25 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
            event.input_data = truncate(event.input_data, limit)
            event.output_data = truncate(event.output_data, limit)

        await super().publish_event(event, channel)
        await super().publish_event(
            event, exec_channel(user_id, graph_id, graph_exec_id)
        )
        await super().publish_event(event, graph_all_channel(user_id, graph_id))

    async def listen(
        self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
        self, user_id: str, graph_id: str, graph_exec_id: str
    ) -> AsyncGenerator[ExecutionEvent, None]:
        async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
        """Stream events for a specific graph execution."""
        async for event in self.listen_events(
            exec_channel(user_id, graph_id, graph_exec_id)
        ):
            yield event

    async def listen_graph(
        self, user_id: str, graph_id: str
    ) -> AsyncGenerator[ExecutionEvent, None]:
        """Stream every event for every execution of ``graph_id``."""
        async for event in self.listen_events(graph_all_channel(user_id, graph_id)):
            yield event


@@ -1682,11 +1729,11 @@ async def create_shared_execution_files(
|
||||
created += 1
|
||||
except UniqueViolationError:
|
||||
logger.debug(
|
||||
f"Skipping shared file record for {file_id}: " f"record already exists"
|
||||
f"Skipping shared file record for {file_id}: record already exists"
|
||||
)
|
||||
except ForeignKeyViolationError:
|
||||
logger.debug(
|
||||
f"Skipping shared file record for {file_id}: " f"file does not exist"
|
||||
f"Skipping shared file record for {file_id}: file does not exist"
|
||||
)
|
||||
return created
|
||||
|
||||
|
||||
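The `truncate` helper called in `_publish` above is imported from elsewhere and is not part of this diff. As a rough, hypothetical stand-in (the replacement shape is an assumption; the tests below only require that an oversized payload shrinks well under the configured limit's order of magnitude), it could look like:

```python
import json
from typing import Any


def truncate(value: Any, limit: int) -> Any:
    """Hypothetical sketch of the helper: leave small payloads intact and
    replace any payload whose JSON encoding exceeds ``limit`` characters
    with a small marker object."""
    encoded = json.dumps(value, default=str)
    if len(encoded) <= limit:
        return value
    return {"truncated": True, "original_size": len(encoded)}


small = truncate({"a": 1}, limit=100)
big = truncate({"long": "x" * 10_000}, limit=10)
```

Whatever the real implementation does, the key property exercised by the tests is that truncation happens in `_publish`, before either channel sees the event.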
@@ -0,0 +1,387 @@
"""Tests for the sharded channel builders + publish/listen paths on
``AsyncRedisExecutionEventBus`` / ``RedisExecutionEventBus``.

These tests are intentionally Prisma-free: they exercise only the in-process
event-routing layer, using mocks for the Redis cluster client. The live
SSUBSCRIBE round-trip is covered by the integration test in
``event_bus_test.py``.
"""

from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.data.execution import (
    AsyncRedisExecutionEventBus,
    ExecutionEventType,
    ExecutionStatus,
    GraphExecutionEvent,
    NodeExecutionEvent,
    RedisExecutionEventBus,
    _graph_scope_tag,
    exec_channel,
    graph_all_channel,
)

# ---------- Hash-tagged channel builders ----------


def test_graph_scope_tag_uses_hash_tag_syntax():
    """The hash tag must look like ``{user/graph}`` so per-exec + per-graph
    channels hash to the same Redis Cluster keyslot."""
    assert _graph_scope_tag("u", "g") == "{u/g}"


def test_exec_channel_nests_scope_tag():
    """Per-exec channel: ``{user/graph}/exec/<exec_id>``."""
    assert exec_channel("u", "g", "e") == "{u/g}/exec/e"


def test_graph_all_channel_nests_scope_tag():
    """Aggregate channel: ``{user/graph}/all`` — keyslot-aligned with per-exec."""
    assert graph_all_channel("u", "g") == "{u/g}/all"


def test_exec_and_graph_channels_share_hash_tag():
    """Invariant: both channels *must* share the ``{user/graph}`` prefix.
    If this breaks, SSUBSCRIBE for per-exec and aggregate routes to different
    shards and the per-graph listener loses some events."""
    exec_ch = exec_channel("u", "g", "e")
    graph_ch = graph_all_channel("u", "g")
    assert exec_ch.startswith("{u/g}")
    assert graph_ch.startswith("{u/g}")


# ---------- NodeExecutionEvent publish → both channels ----------


def _sample_node_event() -> NodeExecutionEvent:
    now = datetime.now(tz=timezone.utc)
    return NodeExecutionEvent(
        user_id="u",
        graph_id="g",
        graph_version=1,
        graph_exec_id="e",
        node_exec_id="ne",
        node_id="nid",
        block_id="bid",
        status=ExecutionStatus.COMPLETED,
        input_data={"a": 1},
        output_data={"o": [1]},
        add_time=now,
        queue_time=None,
        start_time=now,
        end_time=now,
    )


def _sample_graph_event() -> GraphExecutionEvent:
    now = datetime.now(tz=timezone.utc)
    return GraphExecutionEvent(
        id="e",
        user_id="u",
        graph_id="g",
        graph_version=1,
        preset_id=None,
        status=ExecutionStatus.COMPLETED,
        started_at=now,
        ended_at=now,
        stats=GraphExecutionEvent.Stats(
            cost=0, duration=0.1, node_exec_time=0.1, node_exec_count=1
        ),
        inputs={},
        credential_inputs=None,
        nodes_input_masks=None,
        outputs={},
    )


@pytest.mark.asyncio
async def test_async_publish_node_sends_to_both_channels():
    """Node events fan out to BOTH per-exec and aggregate channels so the
    per-graph WS subscriber sees every node update, not just graph-level ones.
    """
    bus = AsyncRedisExecutionEventBus()
    sent_channels: list[str] = []

    async def _capture(self, event, channel_key):
        sent_channels.append(channel_key)

    with patch.object(
        AsyncRedisExecutionEventBus.__mro__[1], "publish_event", _capture
    ):
        await bus._publish_node_exec_update(_sample_node_event())

    assert sent_channels == [
        exec_channel("u", "g", "e"),
        graph_all_channel("u", "g"),
    ]


@pytest.mark.asyncio
async def test_async_publish_graph_sends_to_both_channels():
    bus = AsyncRedisExecutionEventBus()
    sent_channels: list[str] = []

    async def _capture(self, event, channel_key):
        sent_channels.append(channel_key)

    with patch.object(
        AsyncRedisExecutionEventBus.__mro__[1], "publish_event", _capture
    ):
        await bus._publish_graph_exec_update(_sample_graph_event())

    assert sent_channels == [
        exec_channel("u", "g", "e"),
        graph_all_channel("u", "g"),
    ]


@pytest.mark.asyncio
async def test_async_publish_routes_via_type_dispatch():
    """publish() dispatches on the model type — not on status or event_type."""
    bus = AsyncRedisExecutionEventBus()

    with (
        patch.object(bus, "_publish_graph_exec_update", AsyncMock()) as graph_pub,
        patch.object(bus, "_publish_node_exec_update", AsyncMock()) as node_pub,
    ):
        await bus.publish(_sample_graph_event())
        await bus.publish(_sample_node_event())

    graph_pub.assert_awaited_once()
    node_pub.assert_awaited_once()


@pytest.mark.asyncio
async def test_async_publish_truncates_oversized_payload(monkeypatch):
    """Payload truncation applies before sending — size exceeded → replacement."""
    import backend.data.execution as execution

    bus = AsyncRedisExecutionEventBus()
    # Force tiny limit so ``truncate`` rewrites the payload.
    monkeypatch.setattr(execution.config, "max_message_size_limit", 10)

    cluster = MagicMock()
    cluster.execute_command = AsyncMock()
    with patch("backend.data.event_bus.redis.get_redis_async", return_value=cluster):
        await bus.publish(_sample_node_event())

    # Called twice: per-exec and per-graph channel.
    assert cluster.execute_command.await_count == 2


@pytest.mark.asyncio
async def test_async_listen_uses_exec_channel():
    """listen() must subscribe to the per-exec hash-tagged channel."""
    bus = AsyncRedisExecutionEventBus()

    captured: list[str] = []

    async def _listen_events(self, channel_key):
        captured.append(channel_key)
        # Return an empty async-generator so the ``async for`` exits cleanly.
        if False:
            yield  # pragma: no cover

    with patch.object(AsyncRedisExecutionEventBus, "listen_events", _listen_events):
        async for _ in bus.listen("u", "g", "e"):
            break  # pragma: no cover — generator is empty

    assert captured == [exec_channel("u", "g", "e")]


@pytest.mark.asyncio
async def test_async_listen_graph_uses_all_channel():
    """listen_graph() must subscribe to the aggregate hash-tagged channel."""
    bus = AsyncRedisExecutionEventBus()

    captured: list[str] = []

    async def _listen_events(self, channel_key):
        captured.append(channel_key)
        if False:
            yield  # pragma: no cover

    with patch.object(AsyncRedisExecutionEventBus, "listen_events", _listen_events):
        async for _ in bus.listen_graph("u", "g"):
            break  # pragma: no cover — generator is empty

    assert captured == [graph_all_channel("u", "g")]


# ---------- Sync RedisExecutionEventBus (smaller surface; covers branching) ----------


def test_sync_listen_uses_exec_channel():
    bus = RedisExecutionEventBus()

    captured: list[str] = []

    def _listen_events(self, channel_key):
        captured.append(channel_key)
        return iter([])

    with patch.object(RedisExecutionEventBus, "listen_events", _listen_events):
        list(bus.listen("u", "g", "e"))

    assert captured == [exec_channel("u", "g", "e")]


def test_sync_listen_graph_uses_all_channel():
    bus = RedisExecutionEventBus()

    captured: list[str] = []

    def _listen_events(self, channel_key):
        captured.append(channel_key)
        return iter([])

    with patch.object(RedisExecutionEventBus, "listen_events", _listen_events):
        list(bus.listen_graph("u", "g"))

    assert captured == [graph_all_channel("u", "g")]


def test_sync_publish_node_sends_to_both_channels():
    """Sync publish path also fans out to per-exec + per-graph."""
    bus = RedisExecutionEventBus()
    sent: list[str] = []

    def _capture(self, event, channel_key):
        sent.append(channel_key)

    with patch.object(RedisExecutionEventBus.__mro__[1], "publish_event", _capture):
        bus._publish_node_exec_update(_sample_node_event().model_copy())

    assert sent == [
        exec_channel("u", "g", "e"),
        graph_all_channel("u", "g"),
    ]


def test_event_type_is_literal_on_events():
    """event_type is a discriminator literal, not dynamic — the WS fan-out
    relies on ``ExecutionEventType(event_type)`` being stable."""
    node = _sample_node_event()
    graph = _sample_graph_event()
    assert node.event_type == ExecutionEventType.NODE_EXEC_UPDATE
    assert graph.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE


# ---------- Sync publish dispatch + listen yields ----------


def test_sync_publish_dispatches_on_model_type():
    """Sync ``publish()`` routes GraphExecution and NodeExecutionResult to
    their respective helpers — regression guard on the type-dispatch branch."""
    from backend.data.execution import GraphExecution, NodeExecutionResult

    bus = RedisExecutionEventBus()

    graph_like = MagicMock(spec=GraphExecution)
    node_like = MagicMock(spec=NodeExecutionResult)

    with (
        patch.object(bus, "_publish_graph_exec_update") as graph_pub,
        patch.object(bus, "_publish_node_exec_update") as node_pub,
    ):
        bus.publish(graph_like)
        bus.publish(node_like)

    graph_pub.assert_called_once_with(graph_like)
    node_pub.assert_called_once_with(node_like)


def test_sync_publish_graph_exec_update_rebuilds_event():
    """Sync ``_publish_graph_exec_update`` validates the input into a
    GraphExecutionEvent before delegating to ``_publish`` — don't let a raw
    GraphExecution slip through the type-discriminated listener."""
    bus = RedisExecutionEventBus()
    graph_event = _sample_graph_event()

    with patch.object(bus, "_publish") as mock_publish:
        # Feed back the event itself (it's a GraphExecution subclass) to avoid
        # needing a full Graph fixture.
        bus._publish_graph_exec_update(graph_event)

    mock_publish.assert_called_once()
    args = mock_publish.call_args.args
    # The first arg is a GraphExecutionEvent (validated copy).
    assert args[0].event_type == ExecutionEventType.GRAPH_EXEC_UPDATE
    # Channel-routing args match the input.
    assert args[1:] == ("u", "g", "e")


def test_sync_publish_node_exec_update_rebuilds_event():
    """Sync ``_publish_node_exec_update`` validates to NodeExecutionEvent."""
    bus = RedisExecutionEventBus()
    node_event = _sample_node_event()

    with patch.object(bus, "_publish") as mock_publish:
        bus._publish_node_exec_update(node_event)

    mock_publish.assert_called_once()
    args = mock_publish.call_args.args
    assert args[0].event_type == ExecutionEventType.NODE_EXEC_UPDATE
    assert args[1:] == ("u", "g", "e")


def test_sync_publish_graph_truncates_inputs_and_outputs(monkeypatch):
    """Sync ``_publish`` must truncate GraphExecutionEvent.inputs/outputs when
    the payload exceeds the cap — protects Redis from oversized frames."""
    import backend.data.execution as execution

    bus = RedisExecutionEventBus()
    monkeypatch.setattr(execution.config, "max_message_size_limit", 4)

    event = _sample_graph_event()
    event.inputs = {"long": "x" * 10_000}
    event.outputs = {"long": ["y" * 10_000]}

    with patch("backend.data.event_bus.redis.get_redis", return_value=MagicMock()):
        bus._publish(event, "u", "g", "e")

    # After _publish runs, inputs/outputs have been truncated in-place.
    import json as _json

    assert len(_json.dumps(event.inputs)) < 1000
    assert len(_json.dumps(event.outputs)) < 1000


def test_sync_listen_yields_events_from_generator():
    """Sync ``listen()`` must yield through every event produced by the
    underlying ``listen_events`` generator."""
    bus = RedisExecutionEventBus()
    node_ev = _sample_node_event()

    def _listen_events(self, channel_key):
        yield node_ev

    with patch.object(RedisExecutionEventBus, "listen_events", _listen_events):
        got = list(bus.listen("u", "g", "e"))

    assert got == [node_ev]


def test_sync_listen_graph_yields_events_from_generator():
    bus = RedisExecutionEventBus()
    graph_ev = _sample_graph_event()

    def _listen_events(self, channel_key):
        yield graph_ev

    with patch.object(RedisExecutionEventBus, "listen_events", _listen_events):
        got = list(bus.listen_graph("u", "g"))

    assert got == [graph_ev]


def test_execution_bus_name_matches_settings():
    """Both sync and async buses must read the same configured bus name — the
    WS subscriber depends on this for channel naming."""
    assert (
        RedisExecutionEventBus().event_bus_name
        == AsyncRedisExecutionEventBus().event_bus_name
    )
@@ -1,5 +1,5 @@
 import logging
-from typing import AsyncGenerator, Literal, Optional, overload
+from typing import Literal, Optional, overload
 
 from prisma.models import AgentNode, AgentPreset, IntegrationWebhook
 from prisma.types import (

@@ -354,18 +354,10 @@ async def publish_webhook_event(event: WebhookEvent):
     )
 
 
-async def listen_for_webhook_events(
-    webhook_id: str, event_type: Optional[str] = None
-) -> AsyncGenerator[WebhookEvent, None]:
-    async for event in _webhook_event_bus.listen_events(
-        f"{webhook_id}/{event_type or '*'}"
-    ):
-        yield event
-
-
 async def wait_for_webhook_event(
-    webhook_id: str, event_type: Optional[str] = None, timeout: Optional[float] = None
+    webhook_id: str, event_type: str, timeout: Optional[float] = None
 ) -> WebhookEvent | None:
+    # Concrete event_type required: sharded pub/sub has no pattern support.
     return await _webhook_event_bus.wait_for_event(
-        f"{webhook_id}/{event_type or '*'}", timeout
+        f"{webhook_id}/{event_type}", timeout
     )
@@ -1,15 +1,23 @@
 from __future__ import annotations
 
+import asyncio
 import logging
 from typing import AsyncGenerator
 
 from pydantic import BaseModel, field_serializer
 
 from backend.api.model import NotificationPayload
 from backend.data.event_bus import AsyncRedisEventBus
+from backend.data.push_sender import send_push_for_user
 from backend.util.settings import Settings
 
 logger = logging.getLogger(__name__)
 _settings = Settings()
 
+# Strong refs for in-flight push fanout tasks. asyncio only keeps weak refs
+# to tasks, so a fire-and-forget create_task can be GC'd mid-run.
+_push_tasks: set[asyncio.Task] = set()
+
 
 class NotificationEvent(BaseModel):
     """Generic notification event destined for websocket delivery."""
@@ -23,6 +31,14 @@ class NotificationEvent(BaseModel):
         return payload.model_dump()
 
 
+async def _safe_send_push(user_id: str, payload: NotificationPayload) -> None:
+    """Deliver web push for a notification, swallowing errors."""
+    try:
+        await send_push_for_user(user_id, payload)
+    except Exception:
+        logger.exception("Failed to send web push for user %s", user_id)
+
+
 class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
     Model = NotificationEvent  # type: ignore
 
@@ -32,9 +48,19 @@ class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
 
     async def publish(self, event: NotificationEvent) -> None:
         await self.publish_event(event, event.user_id)
+        # Skip OS push for onboarding step toasts — those are in-page only.
+        # TODO: remove once the onboarding/wallet rework lands and decides
+        # per-event whether a system notification is desired.
+        if event.payload.model_dump().get("type") == "onboarding":
+            return
+        # Fan out to web push subscriptions in parallel. Fire-and-forget so
+        # publishers never wait on the push service; held in _push_tasks so
+        # the task survives until completion.
+        task = asyncio.create_task(_safe_send_push(event.user_id, event.payload))
+        _push_tasks.add(task)
+        task.add_done_callback(_push_tasks.discard)
 
-    async def listen(
-        self, user_id: str = "*"
-    ) -> AsyncGenerator[NotificationEvent, None]:
+    async def listen(self, user_id: str) -> AsyncGenerator[NotificationEvent, None]:
+        """Stream notifications for a specific user."""
         async for event in self.listen_events(user_id):
             yield event
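The `_push_tasks` set in the hunk above guards against a real asyncio footgun: the event loop holds only weak references to tasks, so a bare `asyncio.create_task(...)` whose result is dropped can be garbage-collected before the coroutine finishes. The pattern in isolation (illustrative names, not the module's API):

```python
import asyncio

# Module-level strong references keep fire-and-forget tasks alive until done.
_background_tasks: set[asyncio.Task] = set()
results: list[str] = []


async def _deliver(name: str) -> None:
    results.append(name)


def fire_and_forget(coro) -> None:
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    # discard drops the strong ref once the task completes,
    # so the set never grows without bound.
    task.add_done_callback(_background_tasks.discard)


async def main() -> None:
    fire_and_forget(_deliver("push-1"))
    # Yield to the loop so the scheduled task gets to run.
    await asyncio.sleep(0)
    await asyncio.sleep(0)


asyncio.run(main())
```

The add/discard pair is the same shape as `_push_tasks.add(task)` / `task.add_done_callback(_push_tasks.discard)` in the diff.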
autogpt_platform/backend/backend/data/notification_bus_test.py (new file, 145 additions)
@@ -0,0 +1,145 @@
"""Tests for AsyncRedisNotificationEventBus.

Covers the tiny delegation surface: publish → publish_event(user_id), listen
→ listen_events(user_id), and the payload serializer that ensures extra
fields survive the Redis round-trip.
"""

from unittest.mock import AsyncMock, patch

import pytest

from backend.api.model import NotificationPayload
from backend.data.notification_bus import (
    AsyncRedisNotificationEventBus,
    NotificationEvent,
)


def test_notification_event_serializes_payload_including_extras():
    """``NotificationPayload`` allows extra fields; the bus serializer must
    preserve them. Dropping extras breaks feature payloads like CopilotCompletion."""
    payload = NotificationPayload(type="info", event="hey", extra_field="survive me")
    event = NotificationEvent(user_id="u", payload=payload)
    dumped = event.model_dump()
    assert dumped["payload"]["type"] == "info"
    assert dumped["payload"]["event"] == "hey"
    assert dumped["payload"]["extra_field"] == "survive me"


@pytest.mark.asyncio
async def test_publish_calls_publish_event_with_user_id_channel():
    """publish(event) → publish_event(event, channel_key=event.user_id)."""
    bus = AsyncRedisNotificationEventBus()
    event = NotificationEvent(
        user_id="user-42", payload=NotificationPayload(type="info", event="hi")
    )

    with patch.object(
        AsyncRedisNotificationEventBus, "publish_event", AsyncMock()
    ) as mock_pub:
        await bus.publish(event)

    mock_pub.assert_awaited_once()
    args = mock_pub.await_args.args
    # Pydantic may pass the event as a positional; regardless, user_id is the channel.
    assert args[-1] == "user-42"


@pytest.mark.asyncio
async def test_listen_delegates_to_listen_events_for_user():
    """listen(user_id) must subscribe on the per-user channel."""
    bus = AsyncRedisNotificationEventBus()

    captured: list[str] = []

    async def _listen_events(self, channel_key):
        captured.append(channel_key)
        if False:
            yield  # pragma: no cover

    with patch.object(AsyncRedisNotificationEventBus, "listen_events", _listen_events):
        async for _ in bus.listen("user-42"):
            break  # pragma: no cover — generator empty

    assert captured == ["user-42"]


def test_event_bus_name_is_configured() -> None:
    """The notification bus uses a distinct namespace from the execution bus,
    so WS exec channels and notification channels never collide."""
    bus = AsyncRedisNotificationEventBus()
    assert bus.event_bus_name  # non-empty, configured via Settings


@pytest.mark.asyncio
async def test_publish_fans_out_to_web_push():
    """publish() must also kick off web-push fanout for the user."""
    bus = AsyncRedisNotificationEventBus()
    event = NotificationEvent(
        user_id="user-42", payload=NotificationPayload(type="info", event="hi")
    )

    with (
        patch.object(AsyncRedisNotificationEventBus, "publish_event", AsyncMock()),
        patch(
            "backend.data.notification_bus.send_push_for_user",
            new_callable=AsyncMock,
        ) as mock_push,
    ):
        await bus.publish(event)
        # create_task is fire-and-forget — let the event loop drain the task.
        import asyncio

        for _ in range(3):
            await asyncio.sleep(0)

    mock_push.assert_awaited_once_with("user-42", event.payload)


@pytest.mark.asyncio
async def test_publish_skips_web_push_for_onboarding():
    """Onboarding step toasts are in-page only and must NOT trigger OS push."""
    bus = AsyncRedisNotificationEventBus()
    event = NotificationEvent(
        user_id="user-42",
        payload=NotificationPayload(type="onboarding", event="step_completed"),
    )

    with (
        patch.object(AsyncRedisNotificationEventBus, "publish_event", AsyncMock()),
        patch(
            "backend.data.notification_bus.send_push_for_user",
            new_callable=AsyncMock,
        ) as mock_push,
    ):
        await bus.publish(event)
        import asyncio

        for _ in range(3):
            await asyncio.sleep(0)

    mock_push.assert_not_awaited()


@pytest.mark.asyncio
async def test_publish_swallows_push_errors():
    """A failing push must not propagate or fail the publish."""
    bus = AsyncRedisNotificationEventBus()
    event = NotificationEvent(
        user_id="user-42", payload=NotificationPayload(type="info", event="hi")
    )

    with (
        patch.object(AsyncRedisNotificationEventBus, "publish_event", AsyncMock()),
        patch(
            "backend.data.notification_bus.send_push_for_user",
            new_callable=AsyncMock,
            side_effect=RuntimeError("push backend down"),
        ),
    ):
        await bus.publish(event)  # must not raise
        import asyncio

        for _ in range(3):
            await asyncio.sleep(0)
autogpt_platform/backend/backend/data/push_sender.py (new file, 139 additions)
@@ -0,0 +1,139 @@
"""Fire-and-forget Web Push delivery for notification events."""

import asyncio
import json
import logging
import re
import time
import uuid

from cachetools import TTLCache
from pywebpush import WebPushException, webpush

from backend.api.model import NotificationPayload
from backend.data.push_subscription import PushSubscriptionDTO, validate_push_endpoint
from backend.util.clients import get_database_manager_async_client
from backend.util.settings import Settings

logger = logging.getLogger(__name__)

_settings = Settings()

DEBOUNCE_SECONDS = 5.0
# Per-user debounce timestamps, bounded + auto-evicted so the process doesn't
# accumulate one entry per user forever. Process-local — ineffective across
# multiple WS replicas; acceptable since debounce is a best-effort UX nicety.
_user_last_push: TTLCache[str, float] = TTLCache(maxsize=10_000, ttl=DEBOUNCE_SECONDS)

# Fields to forward from the notification payload to the push message
_FORWARDED_FIELDS = ("session_id", "step", "status", "graph_id", "execution_id")


def _extract_status_code(e: WebPushException) -> int | None:
    """Extract HTTP status code from a pywebpush exception."""
    if e.response is not None:
        return e.response.status_code
    # Fallback: parse "Push failed: <code> <reason>" out of the message in
    # case a future pywebpush version raises without attaching the Response.
    match = re.search(r"Push failed:\s*(\d{3})\b", str(e))
    return int(match.group(1)) if match else None


def _build_push_payload(payload: NotificationPayload) -> str:
    """Build a compact JSON payload (<4KB) for the push message.

    ``id`` is a per-push UUID used by the service worker to build a unique
    notification tag, so repeat pushes don't get coalesced by the OS.
    """
    data = payload.model_dump(mode="json")
    compact: dict[str, object] = {
        "id": uuid.uuid4().hex,
        "type": data.get("type", ""),
        "event": data.get("event", ""),
    }
    for key in _FORWARDED_FIELDS:
        if key in data:
            compact[key] = data[key]
    return json.dumps(compact)


async def send_push_for_user(user_id: str, payload: NotificationPayload) -> None:
    """Send push notifications to all of a user's subscriptions.

    - Skips silently if VAPID keys are not configured.
    - Debounces per-user (collapses pushes within DEBOUNCE_SECONDS).
    - Cleans up stale subscriptions on 410/404 responses.
    """
    vapid_private = _settings.secrets.vapid_private_key
    vapid_public = _settings.secrets.vapid_public_key
    vapid_claim_email = _settings.secrets.vapid_claim_email
    if not vapid_private or not vapid_public or not vapid_claim_email:
        logger.debug("VAPID keys not fully configured, skipping push")
        return
    # py_vapid rejects unprefixed strings deep in webpush(), where they'd
    # surface once per subscription as an "Unexpected error". Catch the
    # misconfiguration here and skip cleanly.
    if not vapid_claim_email.startswith(("mailto:", "https://")):
        logger.warning(
            "VAPID_CLAIM_EMAIL must start with 'mailto:' or 'https://', got %r — "
            "skipping push",
            vapid_claim_email[:40],
        )
        return

    if user_id in _user_last_push:
        logger.debug("Debouncing push for user %s", user_id)
        return
    _user_last_push[user_id] = time.monotonic()

    db_client = get_database_manager_async_client()
    subscriptions = await db_client.get_user_push_subscriptions(user_id)
    if not subscriptions:
        return

    push_data = _build_push_payload(payload)
    vapid_claims: dict[str, str | int] = {"sub": vapid_claim_email}

    async def _send_one(sub: PushSubscriptionDTO) -> None:
        try:
            # Defense-in-depth: reject endpoints that somehow bypassed the
            # subscribe-time check (rows written before the validator existed,
            # direct DB writes, or DNS changes that shifted a trusted host to
            # a blocked IP).
            await validate_push_endpoint(sub.endpoint)
            await asyncio.to_thread(
                webpush,
                subscription_info={
                    "endpoint": sub.endpoint,
                    "keys": {"p256dh": sub.p256dh, "auth": sub.auth},
                },
                data=push_data,
                vapid_private_key=vapid_private,
                vapid_claims=vapid_claims,
            )
        except ValueError as e:
            logger.warning(
                "Refusing push to untrusted endpoint %s: %s",
                sub.endpoint[:60],
                e,
            )
            await db_client.delete_push_subscription(sub.user_id, sub.endpoint)
            return
        except WebPushException as e:
            status = _extract_status_code(e)
            if status in (410, 404):
                logger.info(
                    "Push subscription gone (%s), removing: %s",
                    status,
                    sub.endpoint[:60],
                )
                await db_client.delete_push_subscription(sub.user_id, sub.endpoint)
            else:
                logger.warning("Push failed for %s: %s", sub.endpoint[:60], e)
                await db_client.increment_push_fail_count(sub.user_id, sub.endpoint)
        except Exception:
            logger.exception("Unexpected error sending push to %s", sub.endpoint[:60])

    await asyncio.gather(
        *[_send_one(sub) for sub in subscriptions], return_exceptions=True
    )
autogpt_platform/backend/backend/data/push_sender_test.py (new file, 348 additions)
@@ -0,0 +1,348 @@
"""Tests for fire-and-forget Web Push delivery."""

from unittest.mock import AsyncMock, MagicMock

import pytest

from backend.api.model import NotificationPayload
from backend.data import push_sender
from backend.data.push_subscription import PushSubscriptionDTO


@pytest.fixture(autouse=True)
def clear_debounce():
    """Reset the per-user debounce state between tests."""
    push_sender._user_last_push.clear()
    yield
    push_sender._user_last_push.clear()


@pytest.fixture
def mock_db_client(mocker):
    """Provides a mocked DatabaseManagerAsyncClient with stub async methods."""
    client = MagicMock()
    client.get_user_push_subscriptions = AsyncMock(return_value=[])
    client.delete_push_subscription = AsyncMock()
    client.increment_push_fail_count = AsyncMock()
    mocker.patch(
        "backend.data.push_sender.get_database_manager_async_client",
        return_value=client,
    )
    return client


def _make_settings(
    private: str = "vapid-private",
    public: str = "vapid-public",
    email: str = "mailto:push@agpt.co",
) -> MagicMock:
    settings = MagicMock()
    settings.secrets.vapid_private_key = private
    settings.secrets.vapid_public_key = public
    settings.secrets.vapid_claim_email = email
    return settings


def _make_subscription(
    user_id: str = "user-1",
    endpoint: str = "https://fcm.googleapis.com/fcm/send/sub/1",
    p256dh: str = "test-p256dh",
    auth: str = "test-auth",
) -> PushSubscriptionDTO:
    return PushSubscriptionDTO(
        user_id=user_id, endpoint=endpoint, p256dh=p256dh, auth=auth
    )


def _make_payload(**kwargs) -> NotificationPayload:
    defaults = {"type": "agent_run", "event": "completed"}
    defaults.update(kwargs)
    return NotificationPayload(**defaults)


class TestBuildPushPayload:
    def test_includes_type_and_event(self):
        payload = _make_payload(type="agent_run", event="completed")

        result = push_sender._build_push_payload(payload)

        import json

        parsed = json.loads(result)
        assert parsed["type"] == "agent_run"
        assert parsed["event"] == "completed"

    def test_forwards_known_fields(self):
        payload = _make_payload(
            execution_id="exec-1",
            graph_id="graph-1",
            status="completed",
        )

        result = push_sender._build_push_payload(payload)

        import json

        parsed = json.loads(result)
        assert parsed["execution_id"] == "exec-1"
        assert parsed["graph_id"] == "graph-1"
        assert parsed["status"] == "completed"

    def test_excludes_unknown_fields(self):
        payload = _make_payload(
            custom_field="should-not-appear",
        )

        result = push_sender._build_push_payload(payload)

        import json
||||
|
||||
parsed = json.loads(result)
|
||||
assert "custom_field" not in parsed
|
||||
|
||||
def test_uses_model_dump_json_mode(self):
|
||||
"""Ensure model_dump(mode='json') serializes enums to strings."""
|
||||
payload = _make_payload(type="agent_run", event="completed")
|
||||
|
||||
result = push_sender._build_push_payload(payload)
|
||||
|
||||
import json
|
||||
|
||||
parsed = json.loads(result)
|
||||
assert isinstance(parsed["type"], str)
|
||||
assert isinstance(parsed["event"], str)
|
||||
|
||||
def test_includes_unique_id_per_call(self):
|
||||
"""Each push gets a fresh UUID so repeats don't collapse under the same SW tag."""
|
||||
import json
|
||||
|
||||
payload = _make_payload(type="agent_run", event="completed")
|
||||
|
||||
first = json.loads(push_sender._build_push_payload(payload))
|
||||
second = json.loads(push_sender._build_push_payload(payload))
|
||||
|
||||
assert "id" in first and "id" in second
|
||||
assert first["id"] != second["id"]
|
||||
|
||||
|
||||
class TestSendPushForUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_vapid_private_key_missing(self, mocker, mock_db_client):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings(private=""))
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_db_client.get_user_push_subscriptions.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_vapid_public_key_missing(self, mocker, mock_db_client):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings(public=""))
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_db_client.get_user_push_subscriptions.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_vapid_email_missing(self, mocker, mock_db_client):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings(email=""))
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_db_client.get_user_push_subscriptions.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debounces_rapid_calls(self, mocker, mock_db_client):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
assert mock_db_client.get_user_push_subscriptions.await_count == 1
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
assert mock_db_client.get_user_push_subscriptions.await_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_users_not_debounced(self, mocker, mock_db_client):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
await push_sender.send_push_for_user("user-2", _make_payload())
|
||||
|
||||
assert mock_db_client.get_user_push_subscriptions.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_early_when_no_subscriptions(self, mocker, mock_db_client):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
mock_webpush = mocker.patch("backend.data.push_sender.webpush")
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_webpush.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_webpush_for_each_subscription(self, mocker, mock_db_client):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
sub1 = _make_subscription(endpoint="https://fcm.googleapis.com/fcm/send/sub/1")
|
||||
sub2 = _make_subscription(endpoint="https://fcm.googleapis.com/fcm/send/sub/2")
|
||||
mock_db_client.get_user_push_subscriptions.return_value = [sub1, sub2]
|
||||
mock_webpush = mocker.patch("backend.data.push_sender.webpush")
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
assert mock_webpush.call_count == 2
|
||||
|
||||
calls = mock_webpush.call_args_list
|
||||
endpoints_called = [c.kwargs["subscription_info"]["endpoint"] for c in calls]
|
||||
assert "https://fcm.googleapis.com/fcm/send/sub/1" in endpoints_called
|
||||
assert "https://fcm.googleapis.com/fcm/send/sub/2" in endpoints_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webpush_called_with_correct_args(self, mocker, mock_db_client):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
sub = _make_subscription(
|
||||
user_id="user-1",
|
||||
endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
|
||||
p256dh="key-p256dh",
|
||||
auth="key-auth",
|
||||
)
|
||||
mock_db_client.get_user_push_subscriptions.return_value = [sub]
|
||||
mock_webpush = mocker.patch("backend.data.push_sender.webpush")
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_webpush.assert_called_once()
|
||||
call_kwargs = mock_webpush.call_args.kwargs
|
||||
assert call_kwargs["subscription_info"] == {
|
||||
"endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
|
||||
"keys": {"p256dh": "key-p256dh", "auth": "key-auth"},
|
||||
}
|
||||
assert call_kwargs["vapid_private_key"] == "vapid-private"
|
||||
assert call_kwargs["vapid_claims"] == {"sub": "mailto:push@agpt.co"}
|
||||
assert isinstance(call_kwargs["data"], str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_subscription_on_410_gone(self, mocker, mock_db_client):
|
||||
from pywebpush import WebPushException
|
||||
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
sub = _make_subscription()
|
||||
mock_db_client.get_user_push_subscriptions.return_value = [sub]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 410
|
||||
exc = WebPushException("Gone", response=mock_response)
|
||||
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_db_client.delete_push_subscription.assert_awaited_once_with(
|
||||
sub.user_id, sub.endpoint
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_subscription_on_404(self, mocker, mock_db_client):
|
||||
from pywebpush import WebPushException
|
||||
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
sub = _make_subscription()
|
||||
mock_db_client.get_user_push_subscriptions.return_value = [sub]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
exc = WebPushException("Not Found", response=mock_response)
|
||||
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_db_client.delete_push_subscription.assert_awaited_once_with(
|
||||
sub.user_id, sub.endpoint
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_subscription_when_status_only_in_message(
|
||||
self, mocker, mock_db_client
|
||||
):
|
||||
"""Some pywebpush versions don't expose .response.status_code; the
|
||||
sender must still detect 410 from the exception message and clean up."""
|
||||
from pywebpush import WebPushException
|
||||
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
sub = _make_subscription()
|
||||
mock_db_client.get_user_push_subscriptions.return_value = [sub]
|
||||
|
||||
# No usable response object — only the message carries the status.
|
||||
exc = WebPushException("Push failed: 410 Gone\nResponse body:gone")
|
||||
exc.response = None # type: ignore[assignment]
|
||||
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_db_client.delete_push_subscription.assert_awaited_once_with(
|
||||
sub.user_id, sub.endpoint
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increments_fail_count_on_other_webpush_error(
|
||||
self, mocker, mock_db_client
|
||||
):
|
||||
from pywebpush import WebPushException
|
||||
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
sub = _make_subscription()
|
||||
mock_db_client.get_user_push_subscriptions.return_value = [sub]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 429
|
||||
exc = WebPushException("Too Many Requests", response=mock_response)
|
||||
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_db_client.increment_push_fail_count.assert_awaited_once_with(
|
||||
sub.user_id, sub.endpoint
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increments_fail_count_when_no_response_object(
|
||||
self, mocker, mock_db_client
|
||||
):
|
||||
from pywebpush import WebPushException
|
||||
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
sub = _make_subscription()
|
||||
mock_db_client.get_user_push_subscriptions.return_value = [sub]
|
||||
|
||||
exc = WebPushException("Connection error")
|
||||
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
mock_db_client.increment_push_fail_count.assert_awaited_once_with(
|
||||
sub.user_id, sub.endpoint
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_unexpected_exception_gracefully(
|
||||
self, mocker, mock_db_client
|
||||
):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
sub = _make_subscription()
|
||||
mock_db_client.get_user_push_subscriptions.return_value = [sub]
|
||||
mocker.patch(
|
||||
"backend.data.push_sender.webpush",
|
||||
side_effect=RuntimeError("network down"),
|
||||
)
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debounce_expires_after_threshold(self, mocker, mock_db_client):
|
||||
mocker.patch.object(push_sender, "_settings", _make_settings())
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
assert mock_db_client.get_user_push_subscriptions.await_count == 1
|
||||
|
||||
# Simulate TTL expiry (cachetools evicts on access after TTL elapses).
|
||||
push_sender._user_last_push.pop("user-1", None)
|
||||
|
||||
await push_sender.send_push_for_user("user-1", _make_payload())
|
||||
assert mock_db_client.get_user_push_subscriptions.await_count == 2
|
||||
### autogpt_platform/backend/backend/data/push_subscription.py (new file, +142)

```python
"""CRUD operations for Web Push subscriptions (PushSubscription model)."""

import logging
from datetime import datetime, timezone

from prisma.models import PushSubscription
from pydantic import BaseModel

from backend.util.request import validate_url_host

logger = logging.getLogger(__name__)


# Hostnames of legitimate Web Push services. Endpoints submitted by
# clients must match one of these; everything else is rejected to prevent
# the backend (which POSTs to the stored URL via pywebpush) from being
# used as an SSRF primitive against internal infrastructure. Covers Chrome/
# Edge/Brave (FCM), Firefox (Autopush), and Safari/macOS (Apple Web Push).
_PUSH_SERVICE_HOSTNAMES: list[str] = [
    "fcm.googleapis.com",
    "updates.push.services.mozilla.com",
    "web.push.apple.com",
]

# Cap on concurrent push subscriptions per user — one entry per device/browser
# is typical, so this comfortably covers real usage while preventing an
# authenticated user from registering unbounded endpoints to amplify outbound
# traffic from the backend.
MAX_SUBSCRIPTIONS_PER_USER = 20

# Delete subscriptions with this many failed push attempts during periodic
# cleanup. Web Push sends occasionally fail transiently; beyond this threshold
# the endpoint is effectively dead and should be removed.
MAX_PUSH_FAILURES = 5


async def validate_push_endpoint(endpoint: str) -> None:
    """Ensure a push-subscription endpoint is an HTTPS URL hosted on a known
    Web Push provider. Raises ``ValueError`` otherwise.

    Called at subscribe time and again before dispatch (defense-in-depth against
    rows written before this check existed or via future codepaths).
    """
    parsed, is_trusted, _ = await validate_url_host(
        endpoint, trusted_hostnames=_PUSH_SERVICE_HOSTNAMES
    )
    if parsed.scheme != "https":
        raise ValueError("Push endpoint must use https://")
    if not is_trusted:
        raise ValueError(
            f"Push endpoint host '{parsed.hostname}' is not a recognised "
            "Web Push service"
        )


class PushSubscriptionDTO(BaseModel):
    """RPC-serializable projection of PushSubscription."""

    user_id: str
    endpoint: str
    p256dh: str
    auth: str

    @staticmethod
    def from_db(model: PushSubscription) -> "PushSubscriptionDTO":
        return PushSubscriptionDTO(
            user_id=model.userId,
            endpoint=model.endpoint,
            p256dh=model.p256dh,
            auth=model.auth,
        )


async def upsert_push_subscription(
    user_id: str,
    endpoint: str,
    p256dh: str,
    auth: str,
    user_agent: str | None = None,
) -> PushSubscription:
    existing = await PushSubscription.prisma().find_many(
        where={"userId": user_id},
    )
    # Allow updates to an existing endpoint; only block when adding a *new* one
    # past the cap.
    has_this_endpoint = any(row.endpoint == endpoint for row in existing)
    if len(existing) >= MAX_SUBSCRIPTIONS_PER_USER and not has_this_endpoint:
        raise ValueError(
            f"Subscription limit of {MAX_SUBSCRIPTIONS_PER_USER} per user reached"
        )
    return await PushSubscription.prisma().upsert(
        where={"userId_endpoint": {"userId": user_id, "endpoint": endpoint}},
        data={
            "create": {
                "userId": user_id,
                "endpoint": endpoint,
                "p256dh": p256dh,
                "auth": auth,
                "userAgent": user_agent,
            },
            "update": {
                "p256dh": p256dh,
                "auth": auth,
                "userAgent": user_agent,
                "failCount": 0,
                "lastFailedAt": None,
            },
        },
    )


async def get_user_push_subscriptions(user_id: str) -> list[PushSubscriptionDTO]:
    rows = await PushSubscription.prisma().find_many(where={"userId": user_id})
    return [PushSubscriptionDTO.from_db(row) for row in rows]


async def delete_push_subscription(user_id: str, endpoint: str) -> None:
    await PushSubscription.prisma().delete_many(
        where={"userId": user_id, "endpoint": endpoint}
    )


async def increment_fail_count(user_id: str, endpoint: str) -> None:
    await PushSubscription.prisma().update_many(
        where={"userId": user_id, "endpoint": endpoint},
        data={
            "failCount": {"increment": 1},
            "lastFailedAt": datetime.now(timezone.utc),
        },
    )


async def cleanup_failed_subscriptions(
    max_failures: int = MAX_PUSH_FAILURES,
) -> int:
    """Delete subscriptions that have exceeded the failure threshold."""
    result = await PushSubscription.prisma().delete_many(
        where={"failCount": {"gte": max_failures}}
    )
    if result:
        logger.info(f"Cleaned up {result} failed push subscriptions")
    return result or 0
```
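`validate_url_host` above is an internal backend helper. The gist of the allowlist check it performs can be sketched with only the standard library — note this sketch only compares hostnames, whereas the real helper is async and presumably does more (e.g. DNS-level validation), so treat it as illustrative:

```python
from urllib.parse import urlsplit

# Same allowlist idea as _PUSH_SERVICE_HOSTNAMES in the module above.
PUSH_SERVICE_HOSTNAMES = {
    "fcm.googleapis.com",
    "updates.push.services.mozilla.com",
    "web.push.apple.com",
}


def check_push_endpoint(endpoint: str) -> None:
    """Reject anything that is not HTTPS on a known push host (sketch only)."""
    parts = urlsplit(endpoint)
    if parts.scheme != "https":
        raise ValueError("Push endpoint must use https://")
    if (parts.hostname or "").lower() not in PUSH_SERVICE_HOSTNAMES:
        raise ValueError(f"Untrusted push host: {parts.hostname!r}")
```

The exact-hostname match (rather than a suffix match) is the important design choice: a suffix check like `endswith("googleapis.com")` would let `evil-googleapis.com`-style hosts through, defeating the SSRF protection.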
### autogpt_platform/backend/backend/data/push_subscription_test.py (new file, +325)

```python
"""Tests for Web Push subscription CRUD operations."""

from datetime import datetime
from unittest.mock import AsyncMock, MagicMock

import pytest

from backend.data import push_subscription


@pytest.fixture
def mock_prisma(mocker):
    """Mock PushSubscription.prisma() and return the mock client."""
    mock_client = MagicMock()
    mock_client.upsert = AsyncMock()
    mock_client.find_many = AsyncMock(return_value=[])
    mock_client.delete_many = AsyncMock()
    mock_client.update_many = AsyncMock()
    mocker.patch(
        "backend.data.push_subscription.PushSubscription.prisma",
        return_value=mock_client,
    )
    return mock_client


class TestUpsertPushSubscription:
    @pytest.mark.asyncio
    async def test_calls_prisma_upsert_with_correct_params(self, mock_prisma):
        mock_prisma.upsert.return_value = MagicMock()

        await push_subscription.upsert_push_subscription(
            user_id="user-1",
            endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
            p256dh="test-p256dh",
            auth="test-auth",
            user_agent="Mozilla/5.0",
        )

        mock_prisma.upsert.assert_awaited_once()
        call_kwargs = mock_prisma.upsert.call_args.kwargs
        assert call_kwargs["where"] == {
            "userId_endpoint": {
                "userId": "user-1",
                "endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
            }
        }
        assert call_kwargs["data"]["create"] == {
            "userId": "user-1",
            "endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
            "p256dh": "test-p256dh",
            "auth": "test-auth",
            "userAgent": "Mozilla/5.0",
        }
        assert call_kwargs["data"]["update"] == {
            "p256dh": "test-p256dh",
            "auth": "test-auth",
            "userAgent": "Mozilla/5.0",
            "failCount": 0,
            "lastFailedAt": None,
        }

    @pytest.mark.asyncio
    async def test_upsert_without_user_agent(self, mock_prisma):
        mock_prisma.upsert.return_value = MagicMock()

        await push_subscription.upsert_push_subscription(
            user_id="user-1",
            endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
            p256dh="test-p256dh",
            auth="test-auth",
        )

        call_kwargs = mock_prisma.upsert.call_args.kwargs
        assert call_kwargs["data"]["create"]["userAgent"] is None
        assert call_kwargs["data"]["update"]["userAgent"] is None

    @pytest.mark.asyncio
    async def test_upsert_returns_prisma_result(self, mock_prisma):
        expected = MagicMock()
        mock_prisma.upsert.return_value = expected

        result = await push_subscription.upsert_push_subscription(
            user_id="user-1",
            endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
            p256dh="test-p256dh",
            auth="test-auth",
        )

        assert result is expected

    @pytest.mark.asyncio
    async def test_upsert_resets_fail_count_on_update(self, mock_prisma):
        mock_prisma.upsert.return_value = MagicMock()

        await push_subscription.upsert_push_subscription(
            user_id="user-1",
            endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
            p256dh="test-p256dh",
            auth="test-auth",
        )

        call_kwargs = mock_prisma.upsert.call_args.kwargs
        assert call_kwargs["data"]["update"]["failCount"] == 0
        assert call_kwargs["data"]["update"]["lastFailedAt"] is None

    @pytest.mark.asyncio
    async def test_rejects_new_endpoint_past_cap(self, mock_prisma):
        existing = [
            MagicMock(endpoint=f"https://fcm.googleapis.com/fcm/send/sub/{i}")
            for i in range(push_subscription.MAX_SUBSCRIPTIONS_PER_USER)
        ]
        mock_prisma.find_many.return_value = existing

        with pytest.raises(ValueError, match="Subscription limit"):
            await push_subscription.upsert_push_subscription(
                user_id="user-1",
                endpoint="https://fcm.googleapis.com/fcm/send/sub/NEW",
                p256dh="test-p256dh",
                auth="test-auth",
            )

        mock_prisma.upsert.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_allows_update_of_existing_endpoint_at_cap(self, mock_prisma):
        existing = [
            MagicMock(endpoint=f"https://fcm.googleapis.com/fcm/send/sub/{i}")
            for i in range(push_subscription.MAX_SUBSCRIPTIONS_PER_USER)
        ]
        mock_prisma.find_many.return_value = existing
        mock_prisma.upsert.return_value = MagicMock()

        await push_subscription.upsert_push_subscription(
            user_id="user-1",
            endpoint="https://fcm.googleapis.com/fcm/send/sub/0",
            p256dh="rotated-p256dh",
            auth="rotated-auth",
        )

        mock_prisma.upsert.assert_awaited_once()


class TestGetUserPushSubscriptions:
    @pytest.mark.asyncio
    async def test_returns_list_of_subscription_dtos(self, mock_prisma):
        sub1 = MagicMock(
            userId="user-1",
            endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
            p256dh="key1",
            auth="auth1",
        )
        sub2 = MagicMock(
            userId="user-1",
            endpoint="https://fcm.googleapis.com/fcm/send/sub/2",
            p256dh="key2",
            auth="auth2",
        )
        mock_prisma.find_many.return_value = [sub1, sub2]

        result = await push_subscription.get_user_push_subscriptions("user-1")

        assert [r.endpoint for r in result] == [
            "https://fcm.googleapis.com/fcm/send/sub/1",
            "https://fcm.googleapis.com/fcm/send/sub/2",
        ]
        assert all(r.user_id == "user-1" for r in result)
        mock_prisma.find_many.assert_awaited_once_with(where={"userId": "user-1"})

    @pytest.mark.asyncio
    async def test_returns_empty_list_when_no_subscriptions(self, mock_prisma):
        mock_prisma.find_many.return_value = []

        result = await push_subscription.get_user_push_subscriptions("user-1")

        assert result == []


class TestDeletePushSubscription:
    @pytest.mark.asyncio
    async def test_deletes_by_user_id_and_endpoint(self, mock_prisma):
        await push_subscription.delete_push_subscription(
            "user-1",
            "https://fcm.googleapis.com/fcm/send/sub/1",
        )

        mock_prisma.delete_many.assert_awaited_once_with(
            where={
                "userId": "user-1",
                "endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
            }
        )


class TestIncrementFailCount:
    @pytest.mark.asyncio
    async def test_includes_user_id_in_where(self, mock_prisma):
        await push_subscription.increment_fail_count(
            "user-1",
            "https://fcm.googleapis.com/fcm/send/sub/1",
        )

        mock_prisma.update_many.assert_awaited_once()
        call_kwargs = mock_prisma.update_many.call_args.kwargs
        assert call_kwargs["where"] == {
            "userId": "user-1",
            "endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
        }

    @pytest.mark.asyncio
    async def test_increments_fail_count_by_one(self, mock_prisma):
        await push_subscription.increment_fail_count(
            "user-1",
            "https://fcm.googleapis.com/fcm/send/sub/1",
        )

        call_kwargs = mock_prisma.update_many.call_args.kwargs
        assert call_kwargs["data"]["failCount"] == {"increment": 1}

    @pytest.mark.asyncio
    async def test_sets_last_failed_at_to_utc_now(self, mock_prisma):
        await push_subscription.increment_fail_count(
            "user-1",
            "https://fcm.googleapis.com/fcm/send/sub/1",
        )

        call_kwargs = mock_prisma.update_many.call_args.kwargs
        last_failed = call_kwargs["data"]["lastFailedAt"]
        assert isinstance(last_failed, datetime)
        assert last_failed.tzinfo is not None


class TestCleanupFailedSubscriptions:
    @pytest.mark.asyncio
    async def test_deletes_subscriptions_exceeding_threshold(self, mock_prisma):
        mock_prisma.delete_many.return_value = 3

        result = await push_subscription.cleanup_failed_subscriptions(
            max_failures=5,
        )

        assert result == 3
        mock_prisma.delete_many.assert_awaited_once_with(
            where={"failCount": {"gte": 5}}
        )

    @pytest.mark.asyncio
    async def test_uses_default_max_failures(self, mock_prisma):
        mock_prisma.delete_many.return_value = 0

        await push_subscription.cleanup_failed_subscriptions()

        call_kwargs = mock_prisma.delete_many.call_args.kwargs
        assert call_kwargs["where"]["failCount"]["gte"] == 5

    @pytest.mark.asyncio
    async def test_returns_zero_when_none_deleted(self, mock_prisma):
        mock_prisma.delete_many.return_value = 0

        result = await push_subscription.cleanup_failed_subscriptions()

        assert result == 0

    @pytest.mark.asyncio
    async def test_returns_zero_when_result_is_none(self, mock_prisma):
        mock_prisma.delete_many.return_value = None

        result = await push_subscription.cleanup_failed_subscriptions()

        assert result == 0


class TestValidatePushEndpoint:
    """Endpoints from clients must land on a known Web Push service — otherwise
    the backend can be coerced into POSTing to internal hosts (SSRF)."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "endpoint",
        [
            "https://fcm.googleapis.com/fcm/send/abc",
            "https://updates.push.services.mozilla.com/wpush/v2/xyz",
            "https://web.push.apple.com/some-token",
        ],
    )
    async def test_allows_known_push_services(self, endpoint):
        await push_subscription.validate_push_endpoint(endpoint)

    @pytest.mark.asyncio
    async def test_rejects_http_scheme(self):
        with pytest.raises(ValueError):
            await push_subscription.validate_push_endpoint(
                "http://fcm.googleapis.com/fcm/send/abc"
            )

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "endpoint",
        [
            "https://localhost/evil",
            "https://127.0.0.1/evil",
            "https://169.254.169.254/latest/meta-data/",
            "https://internal-service.local/api",
            "https://attacker.example.com/push",
        ],
    )
    async def test_rejects_untrusted_hosts(self, endpoint):
        with pytest.raises(ValueError):
            await push_subscription.validate_push_endpoint(endpoint)

    @pytest.mark.asyncio
    async def test_rejects_non_http_scheme(self):
        with pytest.raises(ValueError):
            await push_subscription.validate_push_endpoint("file:///etc/passwd")

    @pytest.mark.asyncio
    async def test_custom_max_failures_threshold(self, mock_prisma):
        mock_prisma.delete_many.return_value = 1

        result = await push_subscription.cleanup_failed_subscriptions(
            max_failures=10,
        )

        assert result == 1
        call_kwargs = mock_prisma.delete_many.call_args.kwargs
        assert call_kwargs["where"]["failCount"]["gte"] == 10
```
### autogpt_platform/backend/backend/data/rabbitmq_test.py (new file, +278)

```python
"""Quorum-queue config assertions + mock-driven publish behaviour for
`AsyncRabbitMQ`. Live-broker scenarios live in `e2e_redis_rabbit_test.py`."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import aio_pika
import pytest

from backend.copilot.executor.utils import (
    COPILOT_EXECUTION_EXCHANGE,
    COPILOT_EXECUTION_QUEUE_NAME,
    COPILOT_EXECUTION_ROUTING_KEY,
    create_copilot_queue_config,
)
from backend.data.rabbitmq import (
    AsyncRabbitMQ,
    Exchange,
    ExchangeType,
    Queue,
    RabbitMQConfig,
)
from backend.executor.utils import (
    GRAPH_EXECUTION_EXCHANGE,
    GRAPH_EXECUTION_QUEUE_NAME,
    GRAPH_EXECUTION_ROUTING_KEY,
    create_execution_queue_config,
)

# ---------- Quorum queue config: classic→quorum rollover guard ----------


def test_graph_execution_queue_is_quorum() -> None:
    """Run queue must declare `x-queue-type=quorum` to survive a single
    broker-node outage (AUTOGPT-SERVER-8ST/SV/SW)."""
    cfg = create_execution_queue_config()
    run = next(q for q in cfg.queues if q.name == GRAPH_EXECUTION_QUEUE_NAME)
    assert run.arguments is not None
    assert run.arguments.get("x-queue-type") == "quorum"
    # _v2 suffix marks the rollover so the old-image consumer keeps draining
    # the unsuffixed classic queue during a rolling deploy.
    assert run.name.endswith("_v2")
    assert run.durable is True
    assert run.exchange is GRAPH_EXECUTION_EXCHANGE


def test_graph_execution_cancel_queue_is_quorum() -> None:
    """Cancel queue must also be quorum — losing cancellations on a node
    flap is just as bad as losing runs."""
    cfg = create_execution_queue_config()
    cancel = next(q for q in cfg.queues if q.name.endswith("cancel_queue_v2"))
    assert cancel.arguments == {"x-queue-type": "quorum"}


def test_copilot_execution_queue_is_quorum_with_consumer_timeout() -> None:
    """Copilot run queue must be quorum + carry a long consumer timeout
    matching the pod's graceful-shutdown window."""
    cfg = create_copilot_queue_config()
    run = next(q for q in cfg.queues if q.name == COPILOT_EXECUTION_QUEUE_NAME)
    assert run.arguments is not None
    assert run.arguments.get("x-queue-type") == "quorum"
    # Timeout must be in milliseconds and substantially larger than the
    # default 30-minute timeout so a 6-hour copilot turn doesn't get
    # cancelled by RabbitMQ mid-execution.
    timeout_ms = run.arguments.get("x-consumer-timeout")
    assert isinstance(timeout_ms, int)
    assert timeout_ms >= 60 * 60 * 1000  # at least 1 hour


def test_copilot_cancel_queue_is_quorum() -> None:
    cfg = create_copilot_queue_config()
    cancel = next(q for q in cfg.queues if q.name.endswith("cancel_queue_v2"))
    assert cancel.arguments == {"x-queue-type": "quorum"}


# ---------- AsyncRabbitMQ.publish_message: mock-driven behaviour ----------


def _make_async_client(
    *, exchange_publish: AsyncMock | None = None
) -> tuple[AsyncRabbitMQ, MagicMock, MagicMock]:
    """Build an AsyncRabbitMQ wired to mock connection/channel/exchange.

    Returns the client, the mock channel, and the mock exchange so tests can
    assert on per-call arguments and tweak side_effects mid-flight.
    """
    cfg = RabbitMQConfig(
        vhost="/",
        exchanges=[
            Exchange(name="test_exchange", type=ExchangeType.DIRECT, durable=True)
        ],
        queues=[
            Queue(
                name="test_queue",
                exchange=Exchange(
                    name="test_exchange", type=ExchangeType.DIRECT, durable=True
                ),
                routing_key="rk",
                arguments={"x-queue-type": "quorum"},
            )
        ],
    )
    client = AsyncRabbitMQ(cfg)

    fake_exchange = MagicMock()
    fake_exchange.publish = exchange_publish or AsyncMock()

    fake_channel = MagicMock()
    fake_channel.is_closed = False
    fake_channel.get_exchange = AsyncMock(return_value=fake_exchange)
    fake_channel.default_exchange = fake_exchange

    fake_connection = MagicMock()
```
|
||||
fake_connection.is_closed = False
|
||||
|
||||
client._connection = fake_connection
|
||||
client._channel = fake_channel
|
||||
return client, fake_channel, fake_exchange
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_100_messages_to_quorum_queue_all_confirmed() -> None:
|
||||
"""A healthy quorum queue publish path must confirm 100/100 publishes
|
||||
with no NACKs."""
|
||||
client, _, fake_exchange = _make_async_client()
|
||||
exchange = Exchange(name="test_exchange", type=ExchangeType.DIRECT)
|
||||
for i in range(100):
|
||||
await client.publish_message(
|
||||
routing_key="rk", message=f"msg-{i}", exchange=exchange
|
||||
)
|
||||
assert fake_exchange.publish.await_count == 100
|
||||
# Every call carried a persistent message — durable on the broker side.
|
||||
for call in fake_exchange.publish.await_args_list:
|
||||
msg = call.args[0]
|
||||
assert isinstance(msg, aio_pika.Message)
|
||||
assert msg.delivery_mode == aio_pika.DeliveryMode.PERSISTENT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_retries_on_delivery_error_then_raises() -> None:
|
||||
"""Broker-side NACK (DeliveryError) must trigger ``func_retry`` and then
|
||||
raise gracefully if every retry fails — never crash the publisher loop."""
|
||||
publish = AsyncMock(
|
||||
side_effect=aio_pika.exceptions.DeliveryError(message=None, frame=None)
|
||||
)
|
||||
client, _, fake_exchange = _make_async_client(exchange_publish=publish)
|
||||
exchange = Exchange(name="test_exchange", type=ExchangeType.DIRECT)
|
||||
|
||||
with pytest.raises(aio_pika.exceptions.DeliveryError):
|
||||
await client.publish_message(
|
||||
routing_key="rk", message="will-nack", exchange=exchange
|
||||
)
|
||||
# ``func_retry`` is configured for 5 attempts in retry.py — assert the
|
||||
# publisher attempted at least once but bounded retries.
|
||||
assert fake_exchange.publish.await_count >= 1
|
||||
assert fake_exchange.publish.await_count <= 10 # generous upper bound
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_retries_after_one_transient_failure() -> None:
|
||||
"""A single transient DeliveryError must NOT propagate — ``func_retry``
|
||||
retries and the second call succeeds."""
|
||||
publish = AsyncMock(
|
||||
side_effect=[
|
||||
aio_pika.exceptions.DeliveryError(message=None, frame=None),
|
||||
None, # second attempt succeeds
|
||||
]
|
||||
)
|
||||
client, _, fake_exchange = _make_async_client(exchange_publish=publish)
|
||||
exchange = Exchange(name="test_exchange", type=ExchangeType.DIRECT)
|
||||
await client.publish_message(
|
||||
routing_key="rk", message="recovers", exchange=exchange
|
||||
)
|
||||
assert fake_exchange.publish.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_reconnects_on_channel_invalid_state() -> None:
|
||||
"""ChannelInvalidStateError must clear the channel and trigger a
|
||||
reconnect-and-retry — the publish_message wrapper handles this
|
||||
explicitly (see the except-clause in rabbitmq.py)."""
|
||||
publish = AsyncMock(
|
||||
side_effect=[
|
||||
aio_pika.exceptions.ChannelInvalidStateError("channel dead"),
|
||||
None,
|
||||
]
|
||||
)
|
||||
client, fake_channel, fake_exchange = _make_async_client(exchange_publish=publish)
|
||||
exchange = Exchange(name="test_exchange", type=ExchangeType.DIRECT)
|
||||
|
||||
# Patch connect() so the reconnect path doesn't try to hit a real broker.
|
||||
async def _fake_connect():
|
||||
# After reconnect the channel must be valid again.
|
||||
client._channel = fake_channel
|
||||
return None
|
||||
|
||||
with patch.object(client, "connect", side_effect=_fake_connect):
|
||||
await client.publish_message(
|
||||
routing_key="rk", message="reconnects", exchange=exchange
|
||||
)
|
||||
# Two publish attempts: the failing one + the post-reconnect retry.
|
||||
assert fake_exchange.publish.await_count == 2
|
||||
|
||||
|
||||
# ---------- Dual-deploy: legacy classic + new quorum publisher in parallel ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dual_deploy_publishes_to_legacy_and_new_queues_in_parallel() -> None:
|
||||
"""Rolling-deploy window: old-image producer publishes to classic queue,
|
||||
new-image to `_v2` quorum queue — both must succeed independently."""
|
||||
legacy_client, _, legacy_exchange = _make_async_client()
|
||||
new_client, _, new_exchange = _make_async_client()
|
||||
|
||||
legacy_routing = "copilot.run" # legacy producers used the same routing key
|
||||
new_routing = COPILOT_EXECUTION_ROUTING_KEY
|
||||
|
||||
legacy_exch = Exchange(name="copilot_execution", type=ExchangeType.DIRECT)
|
||||
new_exch = Exchange(name=COPILOT_EXECUTION_EXCHANGE.name, type=ExchangeType.DIRECT)
|
||||
|
||||
# Interleave 10 publishes from each producer — order doesn't matter.
|
||||
for i in range(10):
|
||||
await legacy_client.publish_message(
|
||||
routing_key=legacy_routing, message=f"legacy-{i}", exchange=legacy_exch
|
||||
)
|
||||
await new_client.publish_message(
|
||||
routing_key=new_routing, message=f"new-{i}", exchange=new_exch
|
||||
)
|
||||
|
||||
assert legacy_exchange.publish.await_count == 10
|
||||
assert new_exchange.publish.await_count == 10
|
||||
|
||||
# Each publisher's routing key landed on its own exchange — no crosstalk.
|
||||
for call in legacy_exchange.publish.await_args_list:
|
||||
assert call.kwargs.get("routing_key") == legacy_routing
|
||||
for call in new_exchange.publish.await_args_list:
|
||||
assert call.kwargs.get("routing_key") == new_routing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dual_deploy_legacy_failure_does_not_affect_new_queue() -> None:
|
||||
"""Legacy classic queue NACKing (AUTOGPT-SERVER-8ST) must not break
|
||||
publishes on the new `_v2` quorum queue."""
|
||||
legacy_publish = AsyncMock(
|
||||
side_effect=aio_pika.exceptions.DeliveryError(message=None, frame=None)
|
||||
)
|
||||
legacy_client, _, _ = _make_async_client(exchange_publish=legacy_publish)
|
||||
new_client, _, new_exchange = _make_async_client()
|
||||
|
||||
legacy_exch = Exchange(name="copilot_execution", type=ExchangeType.DIRECT)
|
||||
new_exch = Exchange(name=COPILOT_EXECUTION_EXCHANGE.name, type=ExchangeType.DIRECT)
|
||||
|
||||
# Legacy raises after retries — caller must catch it.
|
||||
with pytest.raises(aio_pika.exceptions.DeliveryError):
|
||||
await legacy_client.publish_message(
|
||||
routing_key="copilot.run", message="legacy-fail", exchange=legacy_exch
|
||||
)
|
||||
# New publisher continues to work — 5 successful publishes.
|
||||
for i in range(5):
|
||||
await new_client.publish_message(
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
message=f"new-ok-{i}",
|
||||
exchange=new_exch,
|
||||
)
|
||||
assert new_exchange.publish.await_count == 5
|
||||
|
||||
|
||||
# ---------- Configuration sanity for downstream queues ----------
|
||||
|
||||
|
||||
def test_graph_execution_routing_key_constants() -> None:
|
||||
"""Routing key + exchange wiring must stay aligned — guards against the
|
||||
classic→quorum migration accidentally also changing the routing key."""
|
||||
cfg = create_execution_queue_config()
|
||||
run = next(q for q in cfg.queues if q.name == GRAPH_EXECUTION_QUEUE_NAME)
|
||||
assert run.routing_key == GRAPH_EXECUTION_ROUTING_KEY
|
||||
assert GRAPH_EXECUTION_EXCHANGE in cfg.exchanges
|
||||
@@ -1,85 +1,205 @@
import asyncio
import logging
import os

from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.cluster import ClusterNode as AsyncClusterNode
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
from redis.cluster import ClusterNode, RedisCluster

from backend.util.cache import cached
from backend.util.retry import conn_retry

load_dotenv()

# Prefer the cluster env vars so the cluster-only image can co-exist with
# old-image pods still reading REDIS_HOST during a rollout.
HOST = os.getenv("REDIS_CLUSTER_HOST") or os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_CLUSTER_PORT") or os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", None)

# Fail-fast on a wedged endpoint instead of blocking on no-response TCP.
SOCKET_TIMEOUT = float(os.getenv("REDIS_SOCKET_TIMEOUT", "30"))
SOCKET_CONNECT_TIMEOUT = float(os.getenv("REDIS_SOCKET_CONNECT_TIMEOUT", "5"))
# PING on idle sockets to detect half-open connections without waiting for
# the OS TCP keepalive (~2h default).
HEALTH_CHECK_INTERVAL = int(os.getenv("REDIS_HEALTH_CHECK_INTERVAL", "30"))

# Skip the HOST-pinning remap when each shard's announced hostname resolves
# directly (e.g. compose DNS names redis-0/redis-1/redis-2).
USE_ANNOUNCED_ADDRESS = os.getenv("REDIS_USE_ANNOUNCED_ADDRESS", "").lower() in (
    "1",
    "true",
    "yes",
)

logger = logging.getLogger(__name__)

# Aliases so call-sites don't care which class this is.
RedisClient = RedisCluster
AsyncRedisClient = AsyncRedisCluster


def _address_remap(addr: tuple[str, int]) -> tuple[str, int]:
    """Pin each shard to the seed `HOST`, keep its announced port.

    Set `REDIS_USE_ANNOUNCED_ADDRESS=true` when the announced shard FQDNs
    resolve directly (e.g. each pod has its own DNS).
    """
    if USE_ANNOUNCED_ADDRESS:
        return addr
    _, port = addr
    return HOST, port


@conn_retry("Redis", "Acquiring connection")
def connect() -> RedisClient:
    c = RedisCluster(
        startup_nodes=[ClusterNode(HOST, PORT)],
        password=PASSWORD,
        decode_responses=True,
        socket_timeout=SOCKET_TIMEOUT,
        socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
        socket_keepalive=True,
        health_check_interval=HEALTH_CHECK_INTERVAL,
        address_remap=_address_remap,
    )
    # Close on PING failure so retries don't leak ClusterNodes (AUTOGPT-SERVER-8T1).
    try:
        c.ping()
    except Exception:
        try:
            c.close()
        except Exception:
            pass
        raise
    return c


@conn_retry("Redis", "Releasing connection")
def disconnect():
    get_redis().close()
    get_redis.cache_clear()


@cached(ttl_seconds=3600)
def get_redis() -> RedisClient:
    return connect()


@conn_retry("AsyncRedis", "Acquiring connection")
async def connect_async() -> AsyncRedisClient:
    c = AsyncRedisCluster(
        startup_nodes=[AsyncClusterNode(HOST, PORT)],
        password=PASSWORD,
        decode_responses=True,
        socket_timeout=SOCKET_TIMEOUT,
        socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
        socket_keepalive=True,
        health_check_interval=HEALTH_CHECK_INTERVAL,
        address_remap=_address_remap,
    )
    # Close on PING failure so retries don't leak ClusterNodes (AUTOGPT-SERVER-8V6/8V4/8V3).
    try:
        await c.ping()
    except Exception:
        try:
            await c.close()
        except Exception:
            pass
        raise
    return c


# One AsyncRedisCluster per event loop: the client binds to the loop it was
# first awaited on, so a module-level singleton breaks across test loops.
_async_clients: dict[int, AsyncRedisCluster] = {}


@conn_retry("AsyncRedis", "Releasing connection")
async def disconnect_async():
    loop = asyncio.get_running_loop()
    c = _async_clients.pop(id(loop), None)
    if c is not None:
        await c.close()


async def get_redis_async() -> AsyncRedisClient:
    loop = asyncio.get_running_loop()
    client = _async_clients.get(id(loop))
    if client is None:
        client = await connect_async()
        _async_clients[id(loop)] = client
    return client


# Sharded pub/sub only delivers on the keyslot-owning shard; subscribers
# need a plain (Async)Redis connection pinned to that node.


def resolve_shard_for_channel(channel: str) -> tuple[str, int]:
    """Return the ``(host, port)`` of the shard that owns the channel's keyslot.

    Applies the configured ``_address_remap`` so callers connect through the
    same address the cluster client uses.
    """
    cluster = get_redis()
    node = cluster.get_node_from_key(channel)
    if node is None:
        raise RuntimeError(f"No cluster node owns the keyslot for channel {channel!r}")
    return _address_remap((node.host, node.port))


@conn_retry("RedisShardedPubSub", "Acquiring connection")
def connect_sharded_pubsub(channel: str) -> Redis:
    """Open a plain ``Redis`` connection pinned to the channel's owning shard."""
    host, port = resolve_shard_for_channel(channel)
    # socket_timeout=None: pubsub reads block indefinitely; a spurious
    # read timeout forces a reconnect whose PING races with subscribe-mode.
    c = Redis(
        host=host,
        port=port,
        password=PASSWORD,
        decode_responses=True,
        socket_timeout=None,
        socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
        socket_keepalive=True,
        health_check_interval=HEALTH_CHECK_INTERVAL,
    )
    try:
        c.ping()
    except Exception:
        try:
            c.close()
        except Exception:
            pass
        raise
    return c


@conn_retry("AsyncRedisShardedPubSub", "Acquiring connection")
async def connect_sharded_pubsub_async(channel: str) -> AsyncRedis:
    """Async variant of :func:`connect_sharded_pubsub`."""
    host, port = resolve_shard_for_channel(channel)
    # socket_timeout=None: see ``connect_sharded_pubsub``.
    c = AsyncRedis(
        host=host,
        port=port,
        password=PASSWORD,
        decode_responses=True,
        socket_timeout=None,
        socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
        socket_keepalive=True,
        health_check_interval=HEALTH_CHECK_INTERVAL,
    )
    try:
        await c.ping()
    except Exception:
        try:
            await c.close()
        except Exception:
            pass
        raise
    return c
autogpt_platform/backend/backend/data/redis_client_test.py (new file, +599 lines)
@@ -0,0 +1,599 @@
"""Unit tests for the cluster-only Redis client in ``redis_client``.
|
||||
|
||||
Patches the redis-py constructors + ``ping()`` so no real Redis is needed.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
|
||||
from redis.cluster import RedisCluster
|
||||
|
||||
import backend.data.redis_client as redis_client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_module_caches() -> None:
|
||||
"""Flush cached singletons between tests so each test sees a fresh connect."""
|
||||
redis_client.get_redis.cache_clear()
|
||||
redis_client._async_clients.clear()
|
||||
|
||||
|
||||
def test_connect_builds_redis_cluster() -> None:
|
||||
with patch.object(redis_client, "RedisCluster", autospec=True) as mock_cluster:
|
||||
mock_cluster.return_value = MagicMock(spec=RedisCluster)
|
||||
client = redis_client.connect()
|
||||
|
||||
mock_cluster.assert_called_once()
|
||||
kwargs = mock_cluster.call_args.kwargs
|
||||
assert kwargs["password"] == redis_client.PASSWORD
|
||||
assert kwargs["decode_responses"] is True
|
||||
assert kwargs["socket_timeout"] == redis_client.SOCKET_TIMEOUT
|
||||
assert kwargs["socket_connect_timeout"] == redis_client.SOCKET_CONNECT_TIMEOUT
|
||||
assert kwargs["socket_keepalive"] is True
|
||||
assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
|
||||
assert kwargs["address_remap"] is redis_client._address_remap
|
||||
startup = kwargs["startup_nodes"]
|
||||
assert len(startup) == 1
|
||||
# ClusterNode resolves "localhost" → "127.0.0.1" internally; both are
|
||||
# valid representations of the configured host.
|
||||
assert startup[0].host in {redis_client.HOST, "127.0.0.1"}
|
||||
assert startup[0].port == redis_client.PORT
|
||||
client.ping.assert_called_once()
|
||||
|
||||
|
||||
def test_address_remap_pins_host_and_preserves_port() -> None:
|
||||
"""Default remap rewrites announced shard host to the configured seed."""
|
||||
with patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", False):
|
||||
assert redis_client._address_remap(("any-other-host", 6380)) == (
|
||||
redis_client.HOST,
|
||||
6380,
|
||||
)
|
||||
|
||||
|
||||
def test_address_remap_passthrough_when_use_announced_address() -> None:
|
||||
"""When announced addresses resolve directly, remap leaves them alone."""
|
||||
with patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", True):
|
||||
assert redis_client._address_remap(("redis-1", 17001)) == ("redis-1", 17001)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_async_builds_async_redis_cluster() -> None:
|
||||
with patch.object(redis_client, "AsyncRedisCluster", autospec=True) as mock_cluster:
|
||||
fake = MagicMock(spec=AsyncRedisCluster)
|
||||
fake.ping = AsyncMock()
|
||||
mock_cluster.return_value = fake
|
||||
client = await redis_client.connect_async()
|
||||
|
||||
mock_cluster.assert_called_once()
|
||||
kwargs = mock_cluster.call_args.kwargs
|
||||
assert kwargs["password"] == redis_client.PASSWORD
|
||||
assert kwargs["decode_responses"] is True
|
||||
assert kwargs["socket_timeout"] == redis_client.SOCKET_TIMEOUT
|
||||
assert kwargs["socket_connect_timeout"] == redis_client.SOCKET_CONNECT_TIMEOUT
|
||||
assert kwargs["socket_keepalive"] is True
|
||||
assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
|
||||
assert kwargs["address_remap"] is redis_client._address_remap
|
||||
startup = kwargs["startup_nodes"]
|
||||
assert len(startup) == 1
|
||||
assert startup[0].host in {redis_client.HOST, "127.0.0.1"}
|
||||
assert startup[0].port == redis_client.PORT
|
||||
client.ping.assert_awaited_once()
|
||||
|
||||
|
||||
def test_get_redis_caches_connect() -> None:
|
||||
with patch.object(redis_client, "connect", autospec=True) as mock_connect:
|
||||
mock_connect.return_value = MagicMock(spec=RedisCluster)
|
||||
client_a = redis_client.get_redis()
|
||||
client_b = redis_client.get_redis()
|
||||
|
||||
assert client_a is client_b
|
||||
mock_connect.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_redis_async_caches_connect() -> None:
|
||||
with patch.object(redis_client, "connect_async", autospec=True) as mock_conn:
|
||||
fake = MagicMock(spec=AsyncRedisCluster)
|
||||
mock_conn.return_value = fake
|
||||
a = await redis_client.get_redis_async()
|
||||
b = await redis_client.get_redis_async()
|
||||
|
||||
assert a is b
|
||||
mock_conn.assert_called_once()
|
||||
|
||||
|
||||
def test_disconnect_closes_cached_client() -> None:
|
||||
with patch.object(redis_client, "connect", autospec=True) as mock_connect:
|
||||
fake = MagicMock(spec=RedisCluster)
|
||||
mock_connect.return_value = fake
|
||||
redis_client.get_redis()
|
||||
redis_client.disconnect()
|
||||
|
||||
fake.close.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_async_closes_cached_client() -> None:
|
||||
with patch.object(redis_client, "connect_async", autospec=True) as mock_connect:
|
||||
fake = MagicMock(spec=AsyncRedisCluster)
|
||||
fake.close = AsyncMock()
|
||||
mock_connect.return_value = fake
|
||||
await redis_client.get_redis_async()
|
||||
await redis_client.disconnect_async()
|
||||
|
||||
fake.close.assert_awaited_once()
|
||||
assert redis_client._async_clients == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_async_no_cached_client_is_noop() -> None:
|
||||
with patch.object(redis_client, "connect_async", autospec=True) as mock_connect:
|
||||
await redis_client.disconnect_async()
|
||||
mock_connect.assert_not_called()
|
||||
|
||||
|
||||
# Sharded pub/sub end-to-end against the local 3-shard compose cluster.
|
||||
# Skipped when no cluster is reachable so CI without docker doesn't flap.
|
||||
|
||||
|
||||
def _has_live_cluster() -> bool:
|
||||
try:
|
||||
c = redis_client.connect()
|
||||
except Exception: # noqa: BLE001 — any connect failure → skip the test
|
||||
return False
|
||||
try:
|
||||
c.close()
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _has_live_cluster(),
|
||||
reason="local redis cluster not reachable; skip sharded pub/sub integration",
|
||||
)
|
||||
def test_sharded_pubsub_end_to_end_sync() -> None:
|
||||
"""SPUBLISH → SSUBSCRIBE round-trip via the sync cluster client. Uses
|
||||
per-node `get_message` because redis-py 6.x's
|
||||
`ClusterPubSub.get_sharded_message(ignore_subscribe_messages=True)`
|
||||
drops every message, not just the subscribe confirmation."""
|
||||
redis_client.get_redis.cache_clear()
|
||||
cluster = redis_client.get_redis()
|
||||
channel = "pr12900:sharded-pubsub:integration"
|
||||
ps = cluster.pubsub()
|
||||
try:
|
||||
ps.ssubscribe(channel)
|
||||
assert cluster.spublish(channel, "hello") >= 1
|
||||
|
||||
# Exactly one node is subscribed (the keyslot owner); read from it.
|
||||
assert len(ps.node_pubsub_mapping) == 1
|
||||
node_ps = next(iter(ps.node_pubsub_mapping.values()))
|
||||
# First message is the ssubscribe confirmation, second is our payload.
|
||||
confirm = node_ps.get_message(timeout=2.0)
|
||||
assert confirm is not None and confirm["type"] == "ssubscribe"
|
||||
received = node_ps.get_message(timeout=5.0)
|
||||
assert received is not None and received["type"] == "smessage"
|
||||
assert received["data"] == "hello"
|
||||
finally:
|
||||
try:
|
||||
ps.sunsubscribe(channel)
|
||||
except Exception:
|
||||
pass
|
||||
ps.close()
|
||||
redis_client.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(
|
||||
not _has_live_cluster(),
|
||||
reason="local redis cluster not reachable; skip sharded pub/sub integration",
|
||||
)
|
||||
async def test_sharded_spublish_end_to_end_async() -> None:
|
||||
"""Async cluster client routes SPUBLISH via ``execute_command``
|
||||
because redis-py 6.x has no async ``spublish()`` wrapper."""
|
||||
redis_client._async_clients.clear()
|
||||
cluster = await redis_client.get_redis_async()
|
||||
try:
|
||||
res = await cluster.execute_command(
|
||||
"SPUBLISH", "pr12900:sharded-pubsub:async", "ping"
|
||||
)
|
||||
# No subscribers — delivered count is 0, but the command must succeed
|
||||
# (i.e. not raise MOVED/ASK or routing errors).
|
||||
assert isinstance(res, int)
|
||||
finally:
|
||||
await redis_client.disconnect_async()
|
||||
|
||||
|
||||
# ---------- Sharded pub/sub: unit tests with mocks ----------
|
||||
|
||||
|
||||
def test_connect_sharded_pubsub_pins_host_and_disables_socket_timeout() -> None:
|
||||
"""`socket_timeout=None` on the pubsub socket: a spurious read timeout
|
||||
forces a reconnect whose PING races with subscribe-mode."""
|
||||
with (
|
||||
patch.object(
|
||||
redis_client,
|
||||
"resolve_shard_for_channel",
|
||||
return_value=("shard-host", 7001),
|
||||
),
|
||||
patch.object(redis_client, "Redis", autospec=True) as mock_redis,
|
||||
):
|
||||
fake_client = MagicMock()
|
||||
mock_redis.return_value = fake_client
|
||||
client = redis_client.connect_sharded_pubsub("chan")
|
||||
|
||||
mock_redis.assert_called_once()
|
||||
kwargs = mock_redis.call_args.kwargs
|
||||
# Pinned to the shard's remapped address.
|
||||
assert kwargs["host"] == "shard-host"
|
||||
assert kwargs["port"] == 7001
|
||||
# socket_timeout MUST be None for pubsub — see docstring in redis_client.py.
|
||||
assert kwargs["socket_timeout"] is None
|
||||
# Idle keepalive + health-check still intact.
|
||||
assert kwargs["socket_keepalive"] is True
|
||||
assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
|
||||
# connect() must PING before returning.
|
||||
client.ping.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_sharded_pubsub_async_disables_socket_timeout() -> None:
|
||||
"""Async sibling of ``test_connect_sharded_pubsub_pins_host...``. Same
|
||||
invariant: socket_timeout=None."""
|
||||
with (
|
||||
patch.object(
|
||||
redis_client,
|
||||
"resolve_shard_for_channel",
|
||||
return_value=("shard-host", 7001),
|
||||
),
|
||||
patch.object(redis_client, "AsyncRedis", autospec=True) as mock_redis,
|
||||
):
|
||||
fake_client = MagicMock()
|
||||
fake_client.ping = AsyncMock()
|
||||
mock_redis.return_value = fake_client
|
||||
client = await redis_client.connect_sharded_pubsub_async("chan")
|
||||
|
||||
kwargs = mock_redis.call_args.kwargs
|
||||
assert kwargs["host"] == "shard-host"
|
||||
assert kwargs["port"] == 7001
|
||||
assert kwargs["socket_timeout"] is None
|
||||
assert kwargs["socket_keepalive"] is True
|
||||
assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
|
||||
client.ping.assert_awaited_once()
|
||||
|
||||
|
||||
def test_resolve_shard_for_channel_applies_address_remap() -> None:
|
||||
"""The resolver must run ``_address_remap`` on the announced address so
|
||||
callers connect through the same address the cluster client uses."""
|
||||
cluster = MagicMock()
|
||||
node = MagicMock()
|
||||
node.host = "announced-host"
|
||||
node.port = 17001
|
||||
cluster.get_node_from_key.return_value = node
|
||||
|
||||
with (
|
||||
patch.object(redis_client, "get_redis", return_value=cluster),
|
||||
patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", False),
|
||||
):
|
||||
host, port = redis_client.resolve_shard_for_channel("chan")
|
||||
|
||||
# Remap pins the host to the seed, keeps the announced port.
|
||||
assert host == redis_client.HOST
|
||||
assert port == 17001
|
||||
|
||||
|
||||
def test_resolve_shard_for_channel_raises_when_no_node_owns_keyslot() -> None:
|
||||
"""Missing cluster node → explicit RuntimeError, not a silent None deref."""
|
||||
cluster = MagicMock()
|
||||
cluster.get_node_from_key.return_value = None
|
||||
|
||||
with patch.object(redis_client, "get_redis", return_value=cluster):
|
||||
with pytest.raises(RuntimeError, match="No cluster node"):
|
||||
redis_client.resolve_shard_for_channel("chan")
|
||||
|
||||
|
||||
def test_resolve_shard_for_channel_passthrough_with_announced_flag() -> None:
|
||||
"""When REDIS_USE_ANNOUNCED_ADDRESS is on, resolver returns the announced
|
||||
address verbatim — no HOST override."""
|
||||
cluster = MagicMock()
|
||||
node = MagicMock()
|
||||
node.host = "redis-2"
|
||||
node.port = 17002
|
||||
cluster.get_node_from_key.return_value = node
|
||||
|
||||
with (
|
||||
patch.object(redis_client, "get_redis", return_value=cluster),
|
||||
patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", True),
|
||||
):
|
||||
host, port = redis_client.resolve_shard_for_channel("chan")
|
||||
|
||||
assert (host, port) == ("redis-2", 17002)
|
||||
|
||||
|
||||
def test_health_check_interval_is_30s_default() -> None:
|
||||
"""Idle PING interval must be <=30s so half-open pubsub sockets don't
|
||||
wait for the OS TCP keepalive (~2h)."""
|
||||
assert redis_client.HEALTH_CHECK_INTERVAL <= 30


def test_connect_sets_health_check_interval() -> None:
    """The cluster client must propagate health_check_interval to each node
    pool — otherwise idle cluster sockets go stale."""
    with patch.object(redis_client, "RedisCluster", autospec=True) as mock_cluster:
        mock_cluster.return_value = MagicMock(spec=RedisCluster)
        redis_client.connect()
    kwargs = mock_cluster.call_args.kwargs
    assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
    assert kwargs["health_check_interval"] > 0


# ---------- K8s same-port shard collapse regression (AUTOGPT-SERVER-8SX) ----------


def test_k8s_shard_collapse_with_announced_address_off_routes_all_to_seed() -> None:
    """In K8s every shard serves on port 6379 behind the seed service, so the
    default `_address_remap` collapses all shards to `(HOST, 6379)` — the
    AUTOGPT-SERVER-8SX bug. Fix: `REDIS_USE_ANNOUNCED_ADDRESS=true`."""
    cluster = MagicMock()
    # 3 shards, each owning a distinct hash slot, but every pod serves on
    # 6379 in K8s — exactly the production topology.
    nodes_by_channel = {
        "{ch-a}/x": MagicMock(host="redis-cluster-redis-0", port=6379),
        "{ch-b}/y": MagicMock(host="redis-cluster-redis-1", port=6379),
        "{ch-c}/z": MagicMock(host="redis-cluster-redis-2", port=6379),
    }
    cluster.get_node_from_key.side_effect = lambda c: nodes_by_channel[c]

    with (
        patch.object(redis_client, "get_redis", return_value=cluster),
        patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", False),
        patch.object(redis_client, "HOST", "redis-dev-seed"),
    ):
        endpoints = {
            channel: redis_client.resolve_shard_for_channel(channel)
            for channel in nodes_by_channel
        }

    # The bug: every shard resolves to the same seed:port endpoint.
    assert len(set(endpoints.values())) == 1, (
        f"Expected the K8s shard-collapse bug, got {endpoints!r}. "
        "If this test is failing it means _address_remap behaviour changed "
        "and the AUTOGPT-SERVER-8SX regression note in this file needs review."
    )
    assert all(ep == ("redis-dev-seed", 6379) for ep in endpoints.values())


def test_k8s_shard_collapse_fixed_with_announced_address_on() -> None:
    """With `REDIS_USE_ANNOUNCED_ADDRESS=true`, each shard's announced FQDN
    passes through, so distinct slots resolve to distinct endpoints."""
    cluster = MagicMock()
    nodes_by_channel = {
        "{ch-a}/x": MagicMock(host="redis-cluster-redis-0", port=6379),
        "{ch-b}/y": MagicMock(host="redis-cluster-redis-1", port=6379),
        "{ch-c}/z": MagicMock(host="redis-cluster-redis-2", port=6379),
    }
    cluster.get_node_from_key.side_effect = lambda c: nodes_by_channel[c]

    with (
        patch.object(redis_client, "get_redis", return_value=cluster),
        patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", True),
        patch.object(redis_client, "HOST", "redis-dev-seed"),
    ):
        endpoints = {
            channel: redis_client.resolve_shard_for_channel(channel)
            for channel in nodes_by_channel
        }

    # Each shard maps to a distinct endpoint — sharded pubsub can route
    # SSUBSCRIBE to the slot owner.
    assert len(set(endpoints.values())) == 3
    assert endpoints["{ch-a}/x"] == ("redis-cluster-redis-0", 6379)
    assert endpoints["{ch-b}/y"] == ("redis-cluster-redis-1", 6379)
    assert endpoints["{ch-c}/z"] == ("redis-cluster-redis-2", 6379)


def test_local_compose_remap_keeps_distinct_ports_per_shard() -> None:
    """Local docker-compose announces distinct ports per shard, so the
    `(host, port)` tuple stays distinct even with `HOST` pinned to seed."""
    cluster = MagicMock()
    nodes_by_channel = {
        "{ch-a}/x": MagicMock(host="redis-0", port=17000),
        "{ch-b}/y": MagicMock(host="redis-1", port=17001),
        "{ch-c}/z": MagicMock(host="redis-2", port=17002),
    }
    cluster.get_node_from_key.side_effect = lambda c: nodes_by_channel[c]

    with (
        patch.object(redis_client, "get_redis", return_value=cluster),
        patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", False),
        patch.object(redis_client, "HOST", "localhost"),
    ):
        endpoints = {
            channel: redis_client.resolve_shard_for_channel(channel)
            for channel in nodes_by_channel
        }

    # Distinct ports → distinct endpoints even after remap pins the host.
    assert len(set(endpoints.values())) == 3
    assert endpoints["{ch-a}/x"] == ("localhost", 17000)
    assert endpoints["{ch-b}/y"] == ("localhost", 17001)
    assert endpoints["{ch-c}/z"] == ("localhost", 17002)
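The three remap tests above all exercise one small routing rule. A minimal sketch of that rule, assuming a hypothetical `address_remap` helper (the real `_address_remap` lives in the backend and is not shown in this diff):

```python
# Hypothetical stand-in for the backend's _address_remap (illustration only):
# with announced addresses off, every node is rewritten to the seed host,
# so K8s shards that all listen on 6379 collapse to a single endpoint.
def address_remap(
    announced: tuple[str, int], *, seed_host: str, use_announced: bool
) -> tuple[str, int]:
    host, port = announced
    return (host, port) if use_announced else (seed_host, port)


shards = [("redis-0", 6379), ("redis-1", 6379), ("redis-2", 6379)]
collapsed = {address_remap(a, seed_host="seed", use_announced=False) for a in shards}
distinct = {address_remap(a, seed_host="seed", use_announced=True) for a in shards}
# collapsed holds one endpoint; distinct keeps all three hosts
```

Local docker-compose avoids the collapse even without the flag only because each shard announces a distinct port, which is exactly what the last test pins down.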


# ---------- Sharded pub/sub: multi-shard integration on the live cluster ----------


def _channel_owner(channel: str) -> tuple[str, int]:
    """Resolve the slot owner for ``channel`` via the live cluster client."""
    cluster = redis_client.get_redis()
    node = cluster.get_node_from_key(channel)
    assert node is not None, f"no slot owner for {channel!r}"
    return node.host, node.port


def _channels_on_distinct_shards(n: int = 3) -> list[str]:
    """Build N hash-tagged channels each mapping to a distinct shard."""
    seen: dict[tuple[str, int], str] = {}
    for tag_id in range(2000):
        chan = "{u" + str(tag_id) + "/g}/exec/e"
        owner = _channel_owner(chan)
        seen.setdefault(owner, chan)
        if len(seen) >= n:
            break
    assert len(seen) >= n, f"could only cover {len(seen)} shards"
    return list(seen.values())[:n]


@pytest.mark.skipif(
    not _has_live_cluster(),
    reason="local redis cluster not reachable; skip multi-shard integration",
)
def test_resolve_shard_for_channel_lands_on_distinct_shards() -> None:
    """3 hash-tagged channels resolve to 3 different shards (slot-distribution)."""
    redis_client.get_redis.cache_clear()
    try:
        channels = _channels_on_distinct_shards(3)
        endpoints = {ch: redis_client.resolve_shard_for_channel(ch) for ch in channels}
        # Three channels → three distinct (host, port) endpoints.
        assert len(set(endpoints.values())) == 3
    finally:
        redis_client.disconnect()


@pytest.mark.skipif(
    not _has_live_cluster(),
    reason="local redis cluster not reachable; skip multi-shard integration",
)
def test_sharded_pubsub_concurrent_subscribers_on_three_shards() -> None:
    """SSUBSCRIBE on three channels owned by three different shards, then
    SPUBLISH to each — every payload must land on its subscriber."""
    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()
    try:
        channels = _channels_on_distinct_shards(3)
        # Subscribe via the cluster client so redis-py's per-node pubsub
        # mapping handles the sharded routing for us.
        ps = cluster.pubsub()
        try:
            for ch in channels:
                ps.ssubscribe(ch)
            # The cluster client opens one node-pubsub per shard owner — three
            # channels on three shards must produce three distinct node clients.
            assert len(ps.node_pubsub_mapping) == 3, (
                "Expected SSUBSCRIBE on 3 channels owned by 3 distinct shards "
                f"to open 3 node-pubsubs, got {len(ps.node_pubsub_mapping)}"
            )
            # Publish to each channel and verify each reaches the right node.
            for i, ch in enumerate(channels):
                assert cluster.spublish(ch, f"payload-{i}") >= 1
            # Drain ssubscribe confirmations + smessages from every node.
            received: dict[str, str] = {}
            for node_ps in ps.node_pubsub_mapping.values():
                # First message per node is the ssubscribe confirm; subsequent
                # smessages carry the test payloads.
                for _ in range(4):  # confirm + at most 1 payload per shard
                    msg = node_ps.get_message(timeout=2.0)
                    if msg is None:
                        break
                    if msg["type"] == "smessage":
                        received[msg["channel"]] = msg["data"]
            for i, ch in enumerate(channels):
                assert ch in received, f"channel {ch!r} got no message"
                assert received[ch] == f"payload-{i}"
        finally:
            for ch in channels:
                try:
                    ps.sunsubscribe(ch)
                except Exception:
                    pass
            ps.close()
    finally:
        redis_client.disconnect()


@pytest.mark.skipif(
    not _has_live_cluster(),
    reason="local redis cluster not reachable; skip multi-shard integration",
)
def test_sharded_pubsub_idle_subscriber_survives_health_check_window() -> None:
    """An SSUBSCRIBE connection must survive an idle window longer than
    `HEALTH_CHECK_INTERVAL` — uses `+5s` to provoke at least one health check."""
    import time as _time

    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()
    channel = "{idle-test}/exec/e"
    client = redis_client.connect_sharded_pubsub(channel)
    ps = client.pubsub()
    try:
        ps.ssubscribe(channel)
        confirm = ps.get_message(timeout=2.0)
        assert confirm is not None and confirm["type"] == "ssubscribe"

        # Idle window — must exceed health_check_interval at least once.
        idle_seconds = redis_client.HEALTH_CHECK_INTERVAL + 5
        _time.sleep(idle_seconds)

        # After idling, publish + receive should still work.
        assert cluster.spublish(channel, "post-idle") >= 1
        msg = ps.get_message(timeout=5.0)
        assert msg is not None and msg["type"] == "smessage"
        assert msg["data"] == "post-idle"
    finally:
        try:
            ps.sunsubscribe(channel)
        except Exception:
            pass
        ps.close()
        client.close()
        redis_client.disconnect()


@pytest.mark.skipif(
    not _has_live_cluster(),
    reason="local redis cluster not reachable; skip multi-shard integration",
)
def test_sharded_pubsub_reconnect_after_forced_disconnect() -> None:
    """Subscriber reconnect after a forced disconnect — close socket, open
    a fresh one, and verify new SPUBLISH events still arrive."""
    redis_client.get_redis.cache_clear()
    cluster = redis_client.get_redis()
    channel = "{reconnect-test}/exec/e"

    # Round 1: subscribe, receive one payload, then close everything.
    client = redis_client.connect_sharded_pubsub(channel)
    ps = client.pubsub()
    try:
        ps.ssubscribe(channel)
        ps.get_message(timeout=2.0)  # ssubscribe confirmation
        assert cluster.spublish(channel, "before-restart") >= 1
        msg = ps.get_message(timeout=5.0)
        assert msg is not None and msg["data"] == "before-restart"
    finally:
        try:
            ps.sunsubscribe(channel)
        except Exception:
            pass
        ps.close()
        client.close()

    # Round 2: a fresh subscriber on the same channel — same routing,
    # different socket. This exercises the reconnect-and-resubscribe path
    # the conn_manager runs after a network blip.
    client2 = redis_client.connect_sharded_pubsub(channel)
    ps2 = client2.pubsub()
    try:
        ps2.ssubscribe(channel)
        ps2.get_message(timeout=2.0)
        assert cluster.spublish(channel, "after-restart") >= 1
        msg = ps2.get_message(timeout=5.0)
        assert msg is not None and msg["data"] == "after-restart"
    finally:
        try:
            ps2.sunsubscribe(channel)
        except Exception:
            pass
        ps2.close()
        client2.close()
        redis_client.disconnect()
@@ -22,8 +22,7 @@ this module can cover.

 from typing import Any, cast

-from redis import Redis
-from redis.asyncio import Redis as AsyncRedis
+from backend.data.redis_client import AsyncRedisClient, RedisClient

 # ---------------------------------------------------------------------------
 # Lua scripts — registered centrally so there is exactly ONE authoritative
@@ -47,9 +46,30 @@ end
 return 0
 """

+# Push to a capped list only when a hash field currently matches the expected
+# value. Returns the new list length, or -1 when the guard fails.
+#
+# KEYS[1] hash key
+# KEYS[2] list key
+# ARGV[1] hash field
+# ARGV[2] expected current value
+# ARGV[3] list value
+# ARGV[4] max list length
+# ARGV[5] list TTL seconds
+_GATED_CAPPED_RPUSH_LUA = """
+local current = redis.call('HGET', KEYS[1], ARGV[1])
+if current ~= ARGV[2] then
+    return -1
+end
+redis.call('RPUSH', KEYS[2], ARGV[3])
+redis.call('LTRIM', KEYS[2], -tonumber(ARGV[4]), -1)
+redis.call('EXPIRE', KEYS[2], tonumber(ARGV[5]))
+return redis.call('LLEN', KEYS[2])
+"""


 async def incr_with_ttl(
-    redis: AsyncRedis,
+    redis: AsyncRedisClient,
     key: str,
     ttl_seconds: int,
     *,
@@ -85,7 +105,7 @@ async def incr_with_ttl(


 def incr_with_ttl_sync(
-    redis: Redis,
+    redis: RedisClient,
     key: str,
     ttl_seconds: int,
     *,
@@ -103,7 +123,7 @@ def incr_with_ttl_sync(


 async def capped_rpush(
-    redis: AsyncRedis,
+    redis: AsyncRedisClient,
     key: str,
     value: str,
     *,
@@ -129,8 +149,42 @@ async def capped_rpush(
     return int(results[-1])


+async def capped_rpush_if_hash_field(
+    redis: AsyncRedisClient,
+    *,
+    hash_key: str,
+    hash_field: str,
+    expected: str,
+    list_key: str,
+    value: str,
+    max_len: int,
+    ttl_seconds: int,
+) -> int | None:
+    """Atomically RPUSH to a bounded list iff a hash field matches.
+
+    Returns the new list length when the push happens, or ``None`` when the
+    hash field does not currently match ``expected``.
+    """
+    result = await cast(
+        "Any",
+        redis.eval(
+            _GATED_CAPPED_RPUSH_LUA,
+            2,
+            hash_key,
+            list_key,
+            hash_field,
+            expected,
+            value,
+            str(max_len),
+            str(ttl_seconds),
+        ),
+    )
+    length = int(result)
+    return None if length < 0 else length


 async def hash_compare_and_set(
-    redis: AsyncRedis,
+    redis: AsyncRedisClient,
     key: str,
     field: str,
     *,
@@ -11,6 +11,7 @@ import pytest

 from backend.data.redis_helpers import (
     capped_rpush,
+    capped_rpush_if_hash_field,
     hash_compare_and_set,
     incr_with_ttl,
     incr_with_ttl_sync,
@@ -56,7 +57,17 @@ class _Fake:
         return len(self.lists.get(key, []))

     async def eval(self, script: str, numkeys: int, *args: Any) -> int:
         # Shim for hash-CAS only.
+        if numkeys == 2:
+            hash_key, list_key = args[0], args[1]
+            field, expected, value, max_len, ttl_seconds = args[2:7]
+            h = self.hashes.setdefault(hash_key, {})
+            if h.get(field) != expected:
+                return -1
+            await self.rpush(list_key, value)
+            await self.ltrim(list_key, -int(max_len), -1)
+            await self.expire(list_key, int(ttl_seconds))
+            return await self.llen(list_key)
+
         key, field, expected, new = args[0], args[1], args[2], args[3]
         h = self.hashes.setdefault(key, {})
         if h.get(field) == expected:
@@ -198,6 +209,50 @@ async def test_capped_rpush_first_push_returns_one() -> None:
     assert r.lists["buf"] == ["only"]


+# ── capped_rpush_if_hash_field ────────────────────────────────────────
+
+
+@pytest.mark.asyncio
+async def test_capped_rpush_if_hash_field_pushes_when_expected_matches() -> None:
+    r = _Fake()
+    r.hashes["meta"] = {"status": "running"}
+
+    length = await capped_rpush_if_hash_field(
+        r,  # type: ignore[arg-type]
+        hash_key="meta",
+        hash_field="status",
+        expected="running",
+        list_key="buf",
+        value="only",
+        max_len=10,
+        ttl_seconds=60,
+    )
+
+    assert length == 1
+    assert r.lists["buf"] == ["only"]
+    assert r.ttls["buf"] == 60
+
+
+@pytest.mark.asyncio
+async def test_capped_rpush_if_hash_field_skips_when_expected_differs() -> None:
+    r = _Fake()
+    r.hashes["meta"] = {"status": "completed"}
+
+    length = await capped_rpush_if_hash_field(
+        r,  # type: ignore[arg-type]
+        hash_key="meta",
+        hash_field="status",
+        expected="running",
+        list_key="buf",
+        value="lost",
+        max_len=10,
+        ttl_seconds=60,
+    )
+
+    assert length is None
+    assert "buf" not in r.lists
+
+
 # ── hash_compare_and_set ───────────────────────────────────────────────
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
 settings = Settings()

 # Cache decorator alias for consistent user lookup caching
-cache_user_lookup = cached(maxsize=1000, ttl_seconds=300)
+cache_user_lookup = cached(maxsize=1000, ttl_seconds=300, shared_cache=True)


 @cache_user_lookup
@@ -509,8 +509,15 @@ async def update_user_timezone(user_id: str, timezone: str) -> User:
         if not user:
             raise ValueError(f"User not found with ID: {user_id}")

-        # Invalidate cache for this user
+        # Invalidate user caches so subsequent reads see the new timezone.
+        # get_user_by_id and get_user_by_email are keyed by a single value
+        # and can be deleted surgically; get_or_create_user is keyed by the
+        # JWT-payload dict so we can't delete a single entry — clear it
+        # entirely.
         get_user_by_id.cache_delete(user_id)
+        if user.email:
+            get_user_by_email.cache_delete(user.email)
+        get_or_create_user.cache_clear()

         return User.from_db(user)
     except Exception as e:
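The comment's distinction between surgical `cache_delete` and wholesale `cache_clear` comes down to how the cache is keyed. A toy illustration (not the backend's `cached` decorator): when the whole payload dict is the key, there is no single handle to delete one user's entry by.

```python
# Toy memo cache keyed by the full payload, mirroring the shape of
# get_or_create_user's JWT-payload keying (illustration only).
cache: dict[tuple, str] = {}


def lookup(payload: dict) -> str:
    key = tuple(sorted(payload.items()))  # the entire payload is the key
    if key not in cache:
        cache[key] = payload["sub"]
    return cache[key]


lookup({"sub": "user-1", "email": "a@example.com"})
lookup({"sub": "user-2", "email": "b@example.com"})
# Deleting "the entry for user-1" would require reconstructing the exact
# payload it was cached under, which the caller no longer has, so the only
# safe invalidation is to clear everything.
cache.clear()
```

By contrast, `get_user_by_id` is keyed by the id alone, so `cache_delete(user_id)` can target exactly one entry.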

autogpt_platform/backend/backend/data/user_test.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""Unit tests for helpers in backend.data.user."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.data import user as user_module
from backend.data.user import update_user_timezone
from backend.util.exceptions import DatabaseError


class TestUpdateUserTimezone:
    @pytest.mark.asyncio
    async def test_invalidates_all_three_user_caches(self):
        prisma_user = MagicMock(id="user-1", email="user@example.com")
        sentinel_user = MagicMock()

        with (
            patch.object(user_module, "PrismaUser") as mock_prisma_user,
            patch.object(user_module.User, "from_db", return_value=sentinel_user),
            patch.object(user_module.get_user_by_id, "cache_delete") as by_id_del,
            patch.object(user_module.get_user_by_email, "cache_delete") as by_email_del,
            patch.object(user_module.get_or_create_user, "cache_clear") as goc_clear,
        ):
            mock_prisma_user.prisma.return_value.update = AsyncMock(
                return_value=prisma_user
            )
            result = await update_user_timezone("user-1", "Europe/London")

        assert result is sentinel_user
        by_id_del.assert_called_once_with("user-1")
        by_email_del.assert_called_once_with("user@example.com")
        goc_clear.assert_called_once_with()

    @pytest.mark.asyncio
    async def test_skips_email_cache_invalidation_when_email_missing(self):
        prisma_user = MagicMock(id="user-1", email=None)
        sentinel_user = MagicMock()

        with (
            patch.object(user_module, "PrismaUser") as mock_prisma_user,
            patch.object(user_module.User, "from_db", return_value=sentinel_user),
            patch.object(user_module.get_user_by_id, "cache_delete") as by_id_del,
            patch.object(user_module.get_user_by_email, "cache_delete") as by_email_del,
            patch.object(user_module.get_or_create_user, "cache_clear") as goc_clear,
        ):
            mock_prisma_user.prisma.return_value.update = AsyncMock(
                return_value=prisma_user
            )
            await update_user_timezone("user-1", "Europe/London")

        by_id_del.assert_called_once_with("user-1")
        by_email_del.assert_not_called()
        goc_clear.assert_called_once_with()

    @pytest.mark.asyncio
    async def test_wraps_prisma_errors_in_database_error(self):
        with patch.object(user_module, "PrismaUser") as mock_prisma_user:
            mock_prisma_user.prisma.return_value.update = AsyncMock(
                side_effect=RuntimeError("connection lost")
            )
            with pytest.raises(DatabaseError) as exc:
                await update_user_timezone("user-1", "Europe/London")

        assert "user-1" in str(exc.value)
        assert "connection lost" in str(exc.value)
@@ -64,9 +64,12 @@ async def clear_insufficient_funds_notifications(user_id: str) -> int:
         redis_client = await redis.get_redis_async()
         pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
         keys = [key async for key in redis_client.scan_iter(match=pattern)]
-        if keys:
-            return await redis_client.delete(*keys)
-        return 0
+        # Keys here span multiple graph IDs and therefore multiple cluster
+        # slots — a bulk DELETE would raise CROSSSLOT, so delete per key.
+        deleted = 0
+        for key in keys:
+            deleted += await redis_client.delete(key)
+        return deleted
     except Exception as e:
         logger.warning(
             f"Failed to clear insufficient funds notification flags for user "
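The CROSSSLOT error the comment mentions arises because Redis Cluster derives each key's slot from its hash tag, the first `{...}` segment, or from the whole key when no tag is present. A sketch of the tag-extraction rule (`slot_key` is an illustrative name, and the per-user keys in this module are assumed to be untagged):

```python
def slot_key(key: str) -> str:
    # Redis Cluster hashes only the first non-empty {...} segment, if any.
    start = key.find("{")
    if start != -1:
        end = key.find("}", start + 1)
        if end > start + 1:
            return key[start + 1 : end]
    return key


# Untagged per-graph keys hash independently, so a multi-key DEL may cross
# slots. A shared {user} tag would pin them to one slot:
#   slot_key("funds:{u1}:graphA") == slot_key("funds:{u1}:graphB") == "u1"
```

Retagging the keys would make the bulk `DELETE` legal again, but it would also change the existing key layout, so the per-key delete loop is the less invasive fix.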

@@ -7,16 +7,12 @@ import time
 from typing import TYPE_CHECKING, Any, cast

 if TYPE_CHECKING:
-    from redis import Redis
-    from redis.asyncio import Redis as AsyncRedis
+    from backend.data.redis_client import AsyncRedisClient, RedisClient

 logger = logging.getLogger(__name__)

-# Lua CAS release: only delete the key if the stored value still matches our
-# owner_id. Returns 1 on delete, 0 on no-op. This makes release() safe against
-# the race where an external caller (e.g. mark_session_completed's force-release)
-# deletes our key and a new owner acquires it before our release() fires — without
-# the CAS guard, release() would wipe the successor's valid lock.
+# CAS release: DEL only when the stored owner still matches — guards against
+# wiping a successor's lock after an external force-release.
 _RELEASE_LUA = (
     "if redis.call('get', KEYS[1]) == ARGV[1] then "
     "return redis.call('del', KEYS[1]) "
@@ -27,7 +23,9 @@ _RELEASE_LUA = (
 class ClusterLock:
     """Simple Redis-based distributed lock for preventing duplicate execution."""

-    def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
+    def __init__(
+        self, redis: "RedisClient", key: str, owner_id: str, timeout: int = 300
+    ):
         self.redis = redis
         self.key = key
         self.owner_id = owner_id
@@ -150,7 +148,7 @@ class AsyncClusterLock:
     """Async Redis-based distributed lock for preventing duplicate execution."""

     def __init__(
-        self, redis: "AsyncRedis", key: str, owner_id: str, timeout: int = 300
+        self, redis: "AsyncRedisClient", key: str, owner_id: str, timeout: int = 300
     ):
         self.redis = redis
         self.key = key
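The force-release race that `_RELEASE_LUA` guards against can be modelled in-process. A sketch of the semantics only, with a plain dict standing in for Redis and no atomicity concerns:

```python
def cas_release(store: dict, key: str, owner_id: str) -> int:
    # Delete only if we still own the lock; a successor's lock survives.
    # Mirrors _RELEASE_LUA: returns 1 on delete, 0 on no-op.
    if store.get(key) == owner_id:
        del store[key]
        return 1
    return 0


# Scenario: owner-A's lock was force-released and owner-B re-acquired it
# before owner-A's own release() fired.
store = {"lock:graph-1": "owner-B"}
assert cas_release(store, "lock:graph-1", "owner-A") == 0  # stale release: no-op
assert store["lock:graph-1"] == "owner-B"  # successor's lock intact
assert cas_release(store, "lock:graph-1", "owner-B") == 1
```

A plain `DEL` in place of the guard would have wiped owner-B's valid lock in the first call, which is exactly the race the original comment describes.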
Some files were not shown because too many files have changed in this diff.