mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): remove per-worker Prisma connect, route DB via DatabaseManagerAsyncClient
- Remove local import + db.connect()/disconnect() from CoPilotProcessor.on_executor_start DB calls already route through db_accessors (chat_db, user_db) which fall back to DatabaseManagerAsyncClient RPC when db.is_connected() is False - Fix rate_limit._fetch_user_tier to use user_db().get_user_by_id() instead of PrismaUser.prisma() directly — avoids requiring Prisma connected on worker event loop - Add subscription_tier field to User Pydantic model, mapped in User.from_db() so the RPC path returns the tier value without a direct Prisma connection
This commit is contained in:
@@ -151,9 +151,8 @@ class CoPilotProcessor:
|
||||
This method is called once per worker thread to set up the async event
|
||||
loop and initialize any required resources.
|
||||
|
||||
Prisma is connected here because copilot/db.py and rate_limit.py use
|
||||
the Prisma singleton directly for ChatSession, ChatMessage, and User
|
||||
queries on this worker's event loop.
|
||||
DB operations route through DatabaseManagerAsyncClient (RPC) via the
|
||||
db_accessors pattern — no direct Prisma connection is needed here.
|
||||
"""
|
||||
configure_logging()
|
||||
set_service_name("CoPilotExecutor")
|
||||
@@ -164,14 +163,6 @@ class CoPilotProcessor:
|
||||
)
|
||||
self.execution_thread.start()
|
||||
|
||||
# Connect Prisma for copilot/db.py and rate_limit.py which use
|
||||
# the Prisma singleton directly on this worker's event loop.
|
||||
from backend.data import db as db_module
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
db_module.connect(), self.execution_loop
|
||||
).result(timeout=30)
|
||||
|
||||
# Skip the SDK's per-request CLI version check — the bundled CLI is
|
||||
# already version-matched to the SDK package.
|
||||
os.environ.setdefault("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK", "1")
|
||||
|
||||
@@ -15,6 +15,7 @@ from prisma.models import User as PrismaUser
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.data.db_accessors import user_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.cache import cached
|
||||
|
||||
@@ -409,9 +410,12 @@ async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
|
||||
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
|
||||
cached and then persists after the user is created with a higher tier.
|
||||
"""
|
||||
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
|
||||
if user and user.subscriptionTier: # type: ignore[reportAttributeAccessIssue]
|
||||
return SubscriptionTier(user.subscriptionTier) # type: ignore[reportAttributeAccessIssue]
|
||||
try:
|
||||
user = await user_db().get_user_by_id(user_id)
|
||||
except Exception:
|
||||
raise _UserNotFoundError(user_id)
|
||||
if user.subscription_tier:
|
||||
return SubscriptionTier(user.subscription_tier)
|
||||
raise _UserNotFoundError(user_id)
|
||||
|
||||
|
||||
|
||||
@@ -104,6 +104,11 @@ class User(BaseModel):
|
||||
description="User timezone (IANA timezone identifier or 'not-set')",
|
||||
)
|
||||
|
||||
# Subscription / rate-limit tier
|
||||
subscription_tier: str | None = Field(
|
||||
default=None, description="Subscription tier (FREE, PRO, BUSINESS, ENTERPRISE)"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_user: "PrismaUser") -> "User":
|
||||
"""Convert a database User object to application User model."""
|
||||
@@ -158,6 +163,7 @@ class User(BaseModel):
|
||||
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
|
||||
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
|
||||
timezone=prisma_user.timezone or USER_TIMEZONE_NOT_SET,
|
||||
subscription_tier=prisma_user.subscriptionTier,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user