mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix: use Model.prisma() pattern for pyright compatibility
Switched from backend.data.db.get_prisma() to the standard PlatformLink.prisma() / PlatformLinkToken.prisma() pattern used throughout the codebase.
This commit is contained in:
@@ -21,16 +21,25 @@ from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Security
|
||||
from prisma.models import PlatformLink, PlatformLinkToken
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import backend.data.db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
LINK_TOKEN_EXPIRY_MINUTES = 30
|
||||
|
||||
VALID_PLATFORMS = {
|
||||
"DISCORD",
|
||||
"TELEGRAM",
|
||||
"SLACK",
|
||||
"TEAMS",
|
||||
"WHATSAPP",
|
||||
"GITHUB",
|
||||
"LINEAR",
|
||||
}
|
||||
|
||||
|
||||
# ── Request / Response Models ──────────────────────────────────────────
|
||||
|
||||
@@ -39,7 +48,9 @@ class CreateLinkTokenRequest(BaseModel):
|
||||
"""Request from the bot service to create a linking token."""
|
||||
|
||||
platform: str = Field(
|
||||
description="Platform name: DISCORD, TELEGRAM, SLACK, TEAMS, WHATSAPP, GITHUB, LINEAR"
|
||||
description=(
|
||||
"Platform name: DISCORD, TELEGRAM, SLACK, TEAMS, WHATSAPP, GITHUB, LINEAR"
|
||||
)
|
||||
)
|
||||
platform_user_id: str = Field(description="The user's ID on the platform")
|
||||
platform_username: str | None = Field(
|
||||
@@ -109,22 +120,20 @@ async def create_link_token(
|
||||
platform = request.platform.upper()
|
||||
_validate_platform(platform)
|
||||
|
||||
prisma = backend.data.db.get_prisma()
|
||||
|
||||
# Check if already linked
|
||||
existing = await prisma.platformlink.find_unique(
|
||||
existing = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform_platformUserId": {
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
}
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Platform user {request.platform_user_id} on {platform} "
|
||||
f"is already linked to an account.",
|
||||
detail=(
|
||||
f"Platform user {request.platform_user_id} on {platform} "
|
||||
f"is already linked to an account."
|
||||
),
|
||||
)
|
||||
|
||||
# Generate token
|
||||
@@ -133,7 +142,7 @@ async def create_link_token(
|
||||
minutes=LINK_TOKEN_EXPIRY_MINUTES
|
||||
)
|
||||
|
||||
await prisma.platformlinktoken.create(
|
||||
await PlatformLinkToken.prisma().create(
|
||||
data={
|
||||
"token": token,
|
||||
"platform": platform,
|
||||
@@ -168,21 +177,17 @@ async def get_link_token_status(token: str) -> LinkTokenStatusResponse:
|
||||
"""
|
||||
Called by the bot service to check if a user has completed linking.
|
||||
"""
|
||||
prisma = backend.data.db.get_prisma()
|
||||
|
||||
link_token = await prisma.platformlinktoken.find_unique(where={"token": token})
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token:
|
||||
raise HTTPException(status_code=404, detail="Token not found")
|
||||
|
||||
if link_token.usedAt is not None:
|
||||
# Token was used — find the linked account
|
||||
link = await prisma.platformlink.find_unique(
|
||||
link = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform_platformUserId": {
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
}
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
}
|
||||
)
|
||||
return LinkTokenStatusResponse(
|
||||
@@ -190,7 +195,7 @@ async def get_link_token_status(token: str) -> LinkTokenStatusResponse:
|
||||
user_id=link.userId if link else None,
|
||||
)
|
||||
|
||||
if link_token.expiresAt < datetime.now(timezone.utc):
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
return LinkTokenStatusResponse(status="expired")
|
||||
|
||||
return LinkTokenStatusResponse(status="pending")
|
||||
@@ -211,14 +216,10 @@ async def resolve_platform_user(
|
||||
platform = request.platform.upper()
|
||||
_validate_platform(platform)
|
||||
|
||||
prisma = backend.data.db.get_prisma()
|
||||
|
||||
link = await prisma.platformlink.find_unique(
|
||||
link = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform_platformUserId": {
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
}
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -249,9 +250,7 @@ async def confirm_link_token(
|
||||
Called by the frontend when the user clicks the link and is logged in.
|
||||
Consumes the token and creates the platform link.
|
||||
"""
|
||||
prisma = backend.data.db.get_prisma()
|
||||
|
||||
link_token = await prisma.platformlinktoken.find_unique(where={"token": token})
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token:
|
||||
raise HTTPException(status_code=404, detail="Token not found")
|
||||
@@ -259,16 +258,14 @@ async def confirm_link_token(
|
||||
if link_token.usedAt is not None:
|
||||
raise HTTPException(status_code=410, detail="Token already used")
|
||||
|
||||
if link_token.expiresAt < datetime.now(timezone.utc):
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=410, detail="Token expired")
|
||||
|
||||
# Check if this platform identity is already linked to someone else
|
||||
existing = await prisma.platformlink.find_unique(
|
||||
existing = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform_platformUserId": {
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
}
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
@@ -283,7 +280,7 @@ async def confirm_link_token(
|
||||
)
|
||||
|
||||
# Create the link
|
||||
await prisma.platformlink.create(
|
||||
await PlatformLink.prisma().create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"platform": link_token.platform,
|
||||
@@ -293,7 +290,7 @@ async def confirm_link_token(
|
||||
)
|
||||
|
||||
# Mark token as used
|
||||
await prisma.platformlinktoken.update(
|
||||
await PlatformLinkToken.prisma().update(
|
||||
where={"token": token},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
@@ -323,9 +320,7 @@ async def list_my_links(
|
||||
"""
|
||||
Returns all platform identities linked to the current user's account.
|
||||
"""
|
||||
prisma = backend.data.db.get_prisma()
|
||||
|
||||
links = await prisma.platformlink.find_many(
|
||||
links = await PlatformLink.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"linkedAt": "desc"},
|
||||
)
|
||||
@@ -355,9 +350,7 @@ async def delete_link(
|
||||
Removes a platform link. The user will need to re-link if they
|
||||
want to use the bot on that platform again.
|
||||
"""
|
||||
prisma = backend.data.db.get_prisma()
|
||||
|
||||
link = await prisma.platformlink.find_unique(where={"id": link_id})
|
||||
link = await PlatformLink.prisma().find_unique(where={"id": link_id})
|
||||
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="Link not found")
|
||||
@@ -365,7 +358,7 @@ async def delete_link(
|
||||
if link.userId != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not your link")
|
||||
|
||||
await prisma.platformlink.delete(where={"id": link_id})
|
||||
await PlatformLink.prisma().delete(where={"id": link_id})
|
||||
|
||||
logger.info(
|
||||
f"Unlinked {link.platform}:{link.platformUserId} from user {user_id[-8:]}"
|
||||
@@ -376,20 +369,13 @@ async def delete_link(
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
VALID_PLATFORMS = {
|
||||
"DISCORD",
|
||||
"TELEGRAM",
|
||||
"SLACK",
|
||||
"TEAMS",
|
||||
"WHATSAPP",
|
||||
"GITHUB",
|
||||
"LINEAR",
|
||||
}
|
||||
|
||||
|
||||
def _validate_platform(platform: str) -> None:
|
||||
if platform not in VALID_PLATFORMS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid platform '{platform}'. Must be one of: {', '.join(sorted(VALID_PLATFORMS))}",
|
||||
detail=(
|
||||
f"Invalid platform '{platform}'. "
|
||||
f"Must be one of: {', '.join(sorted(VALID_PLATFORMS))}"
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user