mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
6 Commits
ntindle/nt
...
feat/reddi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bbb9891af | ||
|
|
6849c68bd5 | ||
|
|
750abdf49c | ||
|
|
da120dd6de | ||
|
|
34608a247b | ||
|
|
a065f699ba |
@@ -55,12 +55,7 @@ from backend.integrations.managed_providers.ayrshare import AyrshareManagedProvi
|
||||
from backend.integrations.managed_providers.ayrshare import (
|
||||
settings_available as ayrshare_settings_available,
|
||||
)
|
||||
from backend.integrations.oauth import (
|
||||
CREDENTIALS_BY_PROVIDER,
|
||||
DEVICE_HANDLERS_BY_NAME,
|
||||
HANDLERS_BY_NAME,
|
||||
)
|
||||
from backend.integrations.oauth.device_base import BaseDeviceAuthHandler
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.util.exceptions import (
|
||||
@@ -255,161 +250,6 @@ async def callback(
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
# ================================================================== #
|
||||
# Device Code Grant endpoints (RFC 8628)
|
||||
# ================================================================== #
|
||||
|
||||
|
||||
class DeviceAuthInitiateResponse(BaseModel):
|
||||
state_token: str
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_url: str
|
||||
verification_url_complete: str | None = None
|
||||
expires_in: int
|
||||
interval: int
|
||||
|
||||
|
||||
class DeviceAuthPollRequest(BaseModel):
|
||||
state_token: str
|
||||
|
||||
|
||||
class DeviceAuthPollResponse(BaseModel):
|
||||
status: str
|
||||
credentials: CredentialsMetaResponse | None = None
|
||||
|
||||
|
||||
def _get_device_auth_handler(provider: ProviderName) -> BaseDeviceAuthHandler:
|
||||
provider_key = provider.value if hasattr(provider, "value") else str(provider)
|
||||
if provider_key not in DEVICE_HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No device-auth handler for provider '{provider_key}'",
|
||||
)
|
||||
handler_class = DEVICE_HANDLERS_BY_NAME[provider_key]
|
||||
return handler_class()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{provider}/device-auth/initiate",
|
||||
summary="Initiate device code OAuth flow",
|
||||
)
|
||||
async def device_auth_initiate(
|
||||
provider: Annotated[
|
||||
ProviderName,
|
||||
Path(title="The provider to initiate device auth for"),
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
scopes: Annotated[
|
||||
str, Query(title="Comma-separated list of authorization scopes")
|
||||
] = "",
|
||||
) -> DeviceAuthInitiateResponse:
|
||||
handler = _get_device_auth_handler(provider)
|
||||
requested_scopes = scopes.split(",") if scopes else []
|
||||
requested_scopes = handler.handle_default_scopes(requested_scopes)
|
||||
|
||||
try:
|
||||
initiation = await handler.initiate_device_auth(requested_scopes)
|
||||
except Exception as e:
|
||||
logger.error(f"Device auth initiation failed for {provider}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Failed to initiate device auth: {str(e)}",
|
||||
)
|
||||
|
||||
# Store state with the provider's expiry (not hardcoded 10 min)
|
||||
state_token, _ = await creds_manager.store.store_state_token(
|
||||
user_id=user_id,
|
||||
provider=provider.value if hasattr(provider, "value") else str(provider),
|
||||
scopes=requested_scopes,
|
||||
state_metadata={
|
||||
"flow_type": "device_code",
|
||||
"device_code": initiation.device_code,
|
||||
"interval": initiation.interval,
|
||||
"user_code": initiation.user_code,
|
||||
},
|
||||
)
|
||||
|
||||
return DeviceAuthInitiateResponse(
|
||||
state_token=state_token,
|
||||
device_code=initiation.device_code,
|
||||
user_code=initiation.user_code,
|
||||
verification_url=initiation.verification_url,
|
||||
verification_url_complete=initiation.verification_url_complete,
|
||||
expires_in=initiation.expires_in,
|
||||
interval=initiation.interval,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{provider}/device-auth/poll",
|
||||
summary="Poll device code OAuth flow for completion",
|
||||
)
|
||||
async def device_auth_poll(
|
||||
provider: Annotated[
|
||||
ProviderName,
|
||||
Path(title="The provider to poll device auth for"),
|
||||
],
|
||||
body: DeviceAuthPollRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> DeviceAuthPollResponse:
|
||||
handler = _get_device_auth_handler(provider)
|
||||
|
||||
# Non-consuming read — state survives across many polls
|
||||
valid_state = await creds_manager.store.peek_state_token(
|
||||
user_id, body.state_token, provider
|
||||
)
|
||||
if not valid_state:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired state token",
|
||||
)
|
||||
|
||||
device_code = valid_state.state_metadata.get("device_code")
|
||||
if not device_code:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="State token is not for a device code flow",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await handler.poll_for_tokens(device_code)
|
||||
except Exception as e:
|
||||
logger.error(f"Device auth poll failed for {provider}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Device auth poll failed: {str(e)}",
|
||||
)
|
||||
|
||||
if result.status in ("pending", "slow_down"):
|
||||
return DeviceAuthPollResponse(status=result.status)
|
||||
|
||||
# Terminal state — consume the token so it can't be reused
|
||||
await creds_manager.store.consume_state_token(user_id, body.state_token, provider)
|
||||
|
||||
if result.status == "approved" and result.credentials:
|
||||
credentials = result.credentials
|
||||
credentials.scopes = handler.handle_default_scopes(credentials.scopes)
|
||||
|
||||
if len(credentials.scopes) == 1 and " " in credentials.scopes[0]:
|
||||
credentials.scopes = credentials.scopes[0].split(" ")
|
||||
|
||||
credentials = await _merge_or_create_credential(
|
||||
user_id, provider, credentials, valid_state.credential_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Device auth approved for user {user_id} " f"and provider {provider.value}"
|
||||
)
|
||||
return DeviceAuthPollResponse(
|
||||
status="approved",
|
||||
credentials=to_meta_response(credentials),
|
||||
)
|
||||
|
||||
# denied / expired
|
||||
return DeviceAuthPollResponse(status=result.status)
|
||||
|
||||
|
||||
# Bound the first-time sweep so a slow upstream (e.g. Ayrshare) can't hang
|
||||
# the credential-list endpoint. On timeout we still kick off a fire-and-
|
||||
# forget sweep so provisioning eventually completes; the user just won't
|
||||
|
||||
@@ -43,7 +43,6 @@ _STATIC_PROVIDER_CONFIGS: dict[str, tuple[str, tuple[CredentialsType, ...]]] = {
|
||||
"revid": ("AI-generated short-form video", ("api_key",)),
|
||||
"screenshotone": ("Automated website screenshots", ("api_key",)),
|
||||
"smtp": ("Send email via SMTP", ("user_password",)),
|
||||
"stripe_link": ("Stripe Link wallet for agent payments", ("device_code",)),
|
||||
"unreal_speech": ("Low-cost text-to-speech", ("api_key",)),
|
||||
"webshare_proxy": ("Rotating proxies for scraping", ("api_key",)),
|
||||
}
|
||||
|
||||
@@ -37,9 +37,12 @@ RedditCredentialsInput = CredentialsMetaInput[
|
||||
]
|
||||
|
||||
|
||||
def RedditCredentialsField() -> RedditCredentialsInput:
|
||||
def RedditCredentialsField(
|
||||
required_scopes: set[str] | None = None,
|
||||
) -> RedditCredentialsInput:
|
||||
"""Creates a Reddit credentials input on a block."""
|
||||
return CredentialsField(
|
||||
required_scopes=required_scopes or set(),
|
||||
description="Connect your Reddit account to access Reddit features.",
|
||||
)
|
||||
|
||||
@@ -58,6 +61,10 @@ TEST_CREDENTIALS = OAuth2Credentials(
|
||||
"history",
|
||||
"privatemessages",
|
||||
"flair",
|
||||
"modposts",
|
||||
"modcontributors",
|
||||
"modmail",
|
||||
"modlog",
|
||||
],
|
||||
title="Mock Reddit credentials",
|
||||
username="mock-reddit-username",
|
||||
|
||||
627
autogpt_platform/backend/backend/blocks/reddit_moderation.py
Normal file
627
autogpt_platform/backend/backend/blocks/reddit_moderation.py
Normal file
@@ -0,0 +1,627 @@
|
||||
from typing import Literal
|
||||
|
||||
from praw.models import Comment, Submission
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.blocks.reddit import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
RedditCredentials,
|
||||
RedditCredentialsField,
|
||||
RedditCredentialsInput,
|
||||
get_praw,
|
||||
settings,
|
||||
strip_reddit_prefix,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
MOD_POSTS_SCOPE = {"modposts"}
|
||||
MOD_CONTRIBUTORS_SCOPE = {"modcontributors"}
|
||||
MODMAIL_SCOPE = {"modmail"}
|
||||
REMOVE_MOD_NOTE_MAX_LENGTH = 250
|
||||
BAN_REASON_MAX_LENGTH = 100
|
||||
BAN_MOD_NOTE_MAX_LENGTH = 300
|
||||
|
||||
|
||||
def _get_moderated_thing(
|
||||
creds: RedditCredentials, thing_id: str
|
||||
) -> Comment | Submission:
|
||||
client = get_praw(creds)
|
||||
normalized_id = strip_reddit_prefix(thing_id)
|
||||
if thing_id.startswith("t1_"):
|
||||
return client.comment(id=normalized_id)
|
||||
return client.submission(id=normalized_id)
|
||||
|
||||
|
||||
def _get_thing_id(item: Comment | Submission) -> str:
|
||||
fullname = getattr(item, "fullname", None)
|
||||
if fullname:
|
||||
return fullname
|
||||
if isinstance(item, Comment):
|
||||
return f"t1_{item.id}"
|
||||
return f"t3_{item.id}"
|
||||
|
||||
|
||||
def _get_thing_type(item: Comment | Submission) -> Literal["comment", "submission"]:
|
||||
if _get_thing_id(item).startswith("t1_"):
|
||||
return "comment"
|
||||
return "submission"
|
||||
|
||||
|
||||
class ModQueueBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: RedditCredentialsInput = RedditCredentialsField(
|
||||
required_scopes=MOD_POSTS_SCOPE
|
||||
)
|
||||
subreddit: str = SchemaField(
|
||||
description="Subreddit name, excluding the /r/ prefix",
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of items to fetch from the mod queue",
|
||||
default=25,
|
||||
)
|
||||
only: Literal["submissions", "comments"] | None = SchemaField(
|
||||
description="Filter to only submissions or only comments. Leave blank for both.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_id: str = SchemaField(
|
||||
description="Full Reddit thing ID of a queued item, such as 't3_abc123' or 't1_xyz789'"
|
||||
)
|
||||
item_type: Literal["comment", "submission"] = SchemaField(
|
||||
description="Whether the queued item is a comment or submission"
|
||||
)
|
||||
post_title: str = SchemaField(description="Title of the queued item")
|
||||
author: str = SchemaField(description="Username of the author")
|
||||
permalink: str = SchemaField(description="Full Reddit permalink")
|
||||
reason: str = SchemaField(description="Mod queue reason (if any)")
|
||||
items: list[dict] = SchemaField(description="All queued items as a list")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="166f3083-51da-4cfc-9f7a-57f47b1ba590",
|
||||
description="Fetches the mod queue for a subreddit. Requires moderator access.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=(
|
||||
not settings.secrets.reddit_client_id
|
||||
or not settings.secrets.reddit_client_secret
|
||||
),
|
||||
input_schema=ModQueueBlock.Input,
|
||||
output_schema=ModQueueBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"subreddit": "testsubreddit",
|
||||
"limit": 5,
|
||||
},
|
||||
test_output=[
|
||||
("post_id", "t3_abc123"),
|
||||
("item_type", "submission"),
|
||||
("post_title", "Test queued post"),
|
||||
("author", "testuser"),
|
||||
("permalink", "/r/testsubreddit/comments/abc123/test_queued_post/"),
|
||||
("reason", ""),
|
||||
(
|
||||
"items",
|
||||
[
|
||||
{
|
||||
"id": "t3_abc123",
|
||||
"type": "submission",
|
||||
"title": "Test queued post",
|
||||
"author": "testuser",
|
||||
"permalink": "/r/testsubreddit/comments/abc123/test_queued_post/",
|
||||
"reason": "",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_mod_queue": lambda creds, subreddit, limit, only: [
|
||||
{
|
||||
"id": "t3_abc123",
|
||||
"type": "submission",
|
||||
"title": "Test queued post",
|
||||
"author": "testuser",
|
||||
"permalink": "/r/testsubreddit/comments/abc123/test_queued_post/",
|
||||
"reason": "",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_mod_queue(
|
||||
creds: RedditCredentials,
|
||||
subreddit: str,
|
||||
limit: int,
|
||||
only: Literal["submissions", "comments"] | None,
|
||||
) -> list[dict]:
|
||||
client = get_praw(creds)
|
||||
sub = client.subreddit(subreddit)
|
||||
kwargs: dict = {"limit": limit}
|
||||
if only:
|
||||
kwargs["only"] = only
|
||||
items = []
|
||||
for item in sub.mod.modqueue(**kwargs):
|
||||
items.append(
|
||||
{
|
||||
"id": _get_thing_id(item),
|
||||
"type": _get_thing_type(item),
|
||||
"title": getattr(item, "title", "[comment]"),
|
||||
"author": str(item.author) if item.author else "[deleted]",
|
||||
"permalink": item.permalink,
|
||||
"reason": getattr(item, "mod_reason_title", "") or "",
|
||||
}
|
||||
)
|
||||
return items
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
items = self.get_mod_queue(
|
||||
credentials,
|
||||
subreddit=input_data.subreddit,
|
||||
limit=input_data.limit,
|
||||
only=input_data.only,
|
||||
)
|
||||
for item in items:
|
||||
yield "post_id", item["id"]
|
||||
yield "item_type", item["type"]
|
||||
yield "post_title", item["title"]
|
||||
yield "author", item["author"]
|
||||
yield "permalink", item["permalink"]
|
||||
yield "reason", item["reason"]
|
||||
yield "items", items
|
||||
|
||||
|
||||
class RemoveRedditPostBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: RedditCredentialsInput = RedditCredentialsField(
|
||||
required_scopes=MOD_POSTS_SCOPE
|
||||
)
|
||||
post_id: str = SchemaField(
|
||||
description="ID or fullname of the post/comment to remove, such as 't3_abc123', 't1_xyz789', or bare submission ID 'abc123'",
|
||||
)
|
||||
spam: bool = SchemaField(
|
||||
description="Mark as spam (True) or just remove (False). Spam trains the filter.",
|
||||
default=False,
|
||||
)
|
||||
mod_note: str | None = SchemaField(
|
||||
description="Optional internal moderator note visible only to mods",
|
||||
default=None,
|
||||
max_length=REMOVE_MOD_NOTE_MAX_LENGTH,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_id: str = SchemaField(description="ID of the removed post (pass-through)")
|
||||
success: bool = SchemaField(description="Whether the removal succeeded")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f75643df-0a1a-4240-aa5b-9b2a1b20dcdd",
|
||||
description="Removes a Reddit post or comment as a moderator. Requires 'modposts' scope.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=(
|
||||
not settings.secrets.reddit_client_id
|
||||
or not settings.secrets.reddit_client_secret
|
||||
),
|
||||
input_schema=RemoveRedditPostBlock.Input,
|
||||
output_schema=RemoveRedditPostBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"post_id": "abc123",
|
||||
"spam": False,
|
||||
},
|
||||
test_output=[
|
||||
("post_id", "abc123"),
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"remove_post": lambda creds, post_id, spam, mod_note: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def remove_post(
|
||||
creds: RedditCredentials,
|
||||
post_id: str,
|
||||
spam: bool,
|
||||
mod_note: str | None,
|
||||
) -> bool:
|
||||
thing = _get_moderated_thing(creds, post_id)
|
||||
remove_kwargs: dict[str, bool | str] = {"spam": spam}
|
||||
if mod_note:
|
||||
remove_kwargs["mod_note"] = mod_note[:REMOVE_MOD_NOTE_MAX_LENGTH]
|
||||
thing.mod.remove(**remove_kwargs)
|
||||
return True
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
success = self.remove_post(
|
||||
credentials,
|
||||
post_id=input_data.post_id,
|
||||
spam=input_data.spam,
|
||||
mod_note=input_data.mod_note,
|
||||
)
|
||||
yield "post_id", input_data.post_id
|
||||
yield "success", success
|
||||
|
||||
|
||||
class ApproveRedditPostBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: RedditCredentialsInput = RedditCredentialsField(
|
||||
required_scopes=MOD_POSTS_SCOPE
|
||||
)
|
||||
post_id: str = SchemaField(
|
||||
description="ID or fullname of the post/comment to approve, such as 't3_abc123', 't1_xyz789', or bare submission ID 'abc123'",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_id: str = SchemaField(description="ID of the approved post (pass-through)")
|
||||
success: bool = SchemaField(description="Whether the approval succeeded")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ae695fcf-e1bf-4900-b06c-3ae21d6edf70",
|
||||
description="Approves a Reddit post or comment from the mod queue. Requires 'modposts' scope.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=(
|
||||
not settings.secrets.reddit_client_id
|
||||
or not settings.secrets.reddit_client_secret
|
||||
),
|
||||
input_schema=ApproveRedditPostBlock.Input,
|
||||
output_schema=ApproveRedditPostBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"post_id": "abc123",
|
||||
},
|
||||
test_output=[
|
||||
("post_id", "abc123"),
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"approve_post": lambda creds, post_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def approve_post(creds: RedditCredentials, post_id: str) -> bool:
|
||||
thing = _get_moderated_thing(creds, post_id)
|
||||
thing.mod.approve()
|
||||
return True
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
success = self.approve_post(credentials, post_id=input_data.post_id)
|
||||
yield "post_id", input_data.post_id
|
||||
yield "success", success
|
||||
|
||||
|
||||
class LockRedditPostBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: RedditCredentialsInput = RedditCredentialsField(
|
||||
required_scopes=MOD_POSTS_SCOPE
|
||||
)
|
||||
post_id: str = SchemaField(
|
||||
description="ID or fullname of the post/comment to lock or unlock",
|
||||
)
|
||||
lock: bool = SchemaField(
|
||||
description="True to lock (disable comments/replies), False to unlock",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_id: str = SchemaField(description="ID of the post (pass-through)")
|
||||
locked: bool = SchemaField(description="Current lock state after the action")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1deaf67c-0407-457f-989d-323198073f74",
|
||||
description="Locks or unlocks a Reddit post or comment to prevent or allow replies. Requires 'modposts' scope.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=(
|
||||
not settings.secrets.reddit_client_id
|
||||
or not settings.secrets.reddit_client_secret
|
||||
),
|
||||
input_schema=LockRedditPostBlock.Input,
|
||||
output_schema=LockRedditPostBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"post_id": "abc123",
|
||||
"lock": True,
|
||||
},
|
||||
test_output=[
|
||||
("post_id", "abc123"),
|
||||
("locked", True),
|
||||
],
|
||||
test_mock={"set_lock": lambda creds, post_id, lock: lock},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def set_lock(creds: RedditCredentials, post_id: str, lock: bool) -> bool:
|
||||
thing = _get_moderated_thing(creds, post_id)
|
||||
if lock:
|
||||
thing.mod.lock()
|
||||
else:
|
||||
thing.mod.unlock()
|
||||
return lock
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
locked = self.set_lock(
|
||||
credentials,
|
||||
post_id=input_data.post_id,
|
||||
lock=input_data.lock,
|
||||
)
|
||||
yield "post_id", input_data.post_id
|
||||
yield "locked", locked
|
||||
|
||||
|
||||
class BanSubredditUserBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: RedditCredentialsInput = RedditCredentialsField(
|
||||
required_scopes=MOD_CONTRIBUTORS_SCOPE
|
||||
)
|
||||
subreddit: str = SchemaField(
|
||||
description="Subreddit to ban the user from, excluding the /r/ prefix",
|
||||
)
|
||||
username: str = SchemaField(
|
||||
description="Reddit username to ban (without the u/ prefix)",
|
||||
)
|
||||
duration: int | None = SchemaField(
|
||||
description="Ban duration in days. Leave blank for a permanent ban.",
|
||||
default=None,
|
||||
ge=1,
|
||||
)
|
||||
reason: str = SchemaField(
|
||||
description="Internal moderator-only ban reason (max 100 chars). Use ban_message to explain the ban to the user.",
|
||||
default="Violation of subreddit rules",
|
||||
max_length=BAN_REASON_MAX_LENGTH,
|
||||
)
|
||||
mod_note: str | None = SchemaField(
|
||||
description="Internal moderator note (not shown to the user)",
|
||||
default=None,
|
||||
max_length=BAN_MOD_NOTE_MAX_LENGTH,
|
||||
)
|
||||
ban_message: str | None = SchemaField(
|
||||
description="Optional custom message sent to the user explaining the ban",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
username: str = SchemaField(description="Banned username (pass-through)")
|
||||
subreddit: str = SchemaField(description="Subreddit (pass-through)")
|
||||
success: bool = SchemaField(description="Whether the ban was applied")
|
||||
permanent: bool = SchemaField(description="True if the ban is permanent")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="428d56d4-52d0-47d9-8544-836d13d196c0",
|
||||
description="Bans a user from a subreddit. Requires 'modcontributors' scope.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=(
|
||||
not settings.secrets.reddit_client_id
|
||||
or not settings.secrets.reddit_client_secret
|
||||
),
|
||||
input_schema=BanSubredditUserBlock.Input,
|
||||
output_schema=BanSubredditUserBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"subreddit": "testsubreddit",
|
||||
"username": "spamuser123",
|
||||
"duration": 7,
|
||||
"reason": "Spam",
|
||||
},
|
||||
test_output=[
|
||||
("username", "spamuser123"),
|
||||
("subreddit", "testsubreddit"),
|
||||
("success", True),
|
||||
("permanent", False),
|
||||
],
|
||||
test_mock={
|
||||
"ban_user": lambda creds, subreddit, username, duration, reason, mod_note, ban_message: True
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def ban_user(
|
||||
creds: RedditCredentials,
|
||||
subreddit: str,
|
||||
username: str,
|
||||
duration: int | None,
|
||||
reason: str,
|
||||
mod_note: str | None,
|
||||
ban_message: str | None,
|
||||
) -> bool:
|
||||
if duration is not None and duration <= 0:
|
||||
raise ValueError("Ban duration must be a positive number of days.")
|
||||
|
||||
client = get_praw(creds)
|
||||
sub = client.subreddit(subreddit)
|
||||
ban_kwargs: dict = {"ban_reason": reason[:BAN_REASON_MAX_LENGTH]}
|
||||
if duration is not None:
|
||||
ban_kwargs["duration"] = duration
|
||||
if mod_note:
|
||||
ban_kwargs["note"] = mod_note[:BAN_MOD_NOTE_MAX_LENGTH]
|
||||
if ban_message:
|
||||
ban_kwargs["ban_message"] = ban_message
|
||||
sub.banned.add(username, **ban_kwargs)
|
||||
return True
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
success = self.ban_user(
|
||||
credentials,
|
||||
subreddit=input_data.subreddit,
|
||||
username=input_data.username,
|
||||
duration=input_data.duration,
|
||||
reason=input_data.reason,
|
||||
mod_note=input_data.mod_note,
|
||||
ban_message=input_data.ban_message,
|
||||
)
|
||||
yield "username", input_data.username
|
||||
yield "subreddit", input_data.subreddit
|
||||
yield "success", success
|
||||
yield "permanent", input_data.duration is None
|
||||
|
||||
|
||||
class UnbanSubredditUserBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: RedditCredentialsInput = RedditCredentialsField(
|
||||
required_scopes=MOD_CONTRIBUTORS_SCOPE
|
||||
)
|
||||
subreddit: str = SchemaField(
|
||||
description="Subreddit to unban the user from, excluding the /r/ prefix",
|
||||
)
|
||||
username: str = SchemaField(
|
||||
description="Reddit username to unban (without the u/ prefix)",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
username: str = SchemaField(description="Unbanned username (pass-through)")
|
||||
subreddit: str = SchemaField(description="Subreddit (pass-through)")
|
||||
success: bool = SchemaField(description="Whether the unban succeeded")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="90979f47-605e-4478-a417-39da3d7184ef",
|
||||
description="Unbans a user from a subreddit. Requires 'modcontributors' scope.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=(
|
||||
not settings.secrets.reddit_client_id
|
||||
or not settings.secrets.reddit_client_secret
|
||||
),
|
||||
input_schema=UnbanSubredditUserBlock.Input,
|
||||
output_schema=UnbanSubredditUserBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"subreddit": "testsubreddit",
|
||||
"username": "rehabilitateduser",
|
||||
},
|
||||
test_output=[
|
||||
("username", "rehabilitateduser"),
|
||||
("subreddit", "testsubreddit"),
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"unban_user": lambda creds, subreddit, username: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unban_user(creds: RedditCredentials, subreddit: str, username: str) -> bool:
|
||||
client = get_praw(creds)
|
||||
sub = client.subreddit(subreddit)
|
||||
sub.banned.remove(username)
|
||||
return True
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
success = self.unban_user(
|
||||
credentials,
|
||||
subreddit=input_data.subreddit,
|
||||
username=input_data.username,
|
||||
)
|
||||
yield "username", input_data.username
|
||||
yield "subreddit", input_data.subreddit
|
||||
yield "success", success
|
||||
|
||||
|
||||
class SendModMailBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: RedditCredentialsInput = RedditCredentialsField(
|
||||
required_scopes=MODMAIL_SCOPE
|
||||
)
|
||||
subreddit: str = SchemaField(
|
||||
description="Subreddit to send modmail from, excluding the /r/ prefix",
|
||||
)
|
||||
to_username: str = SchemaField(
|
||||
description="Username to send the modmail to (without u/ prefix)",
|
||||
)
|
||||
subject: str = SchemaField(
|
||||
description="Subject line of the modmail message",
|
||||
)
|
||||
body: str = SchemaField(
|
||||
description="Body of the modmail message",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
conversation_id: str = SchemaField(
|
||||
description="ID of the created modmail conversation"
|
||||
)
|
||||
success: bool = SchemaField(description="Whether the modmail was sent")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="168b919c-0e06-471d-bd46-eb354ed3d278",
|
||||
description="Sends a modmail message from a subreddit to a user. Requires 'modmail' scope.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=(
|
||||
not settings.secrets.reddit_client_id
|
||||
or not settings.secrets.reddit_client_secret
|
||||
),
|
||||
input_schema=SendModMailBlock.Input,
|
||||
output_schema=SendModMailBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"subreddit": "testsubreddit",
|
||||
"to_username": "someuser",
|
||||
"subject": "Warning: Spam",
|
||||
"body": "Please stop posting promotional content.",
|
||||
},
|
||||
test_output=[
|
||||
("conversation_id", "mock_conv_id"),
|
||||
("success", True),
|
||||
],
|
||||
test_mock={
|
||||
"send_modmail": lambda creds, subreddit, to_username, subject, body: "mock_conv_id"
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def send_modmail(
|
||||
creds: RedditCredentials,
|
||||
subreddit: str,
|
||||
to_username: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
) -> str:
|
||||
client = get_praw(creds)
|
||||
sub = client.subreddit(subreddit)
|
||||
conversation = sub.modmail.create(
|
||||
subject=subject,
|
||||
body=body,
|
||||
recipient=to_username,
|
||||
)
|
||||
return conversation.id
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
conversation_id = self.send_modmail(
|
||||
credentials,
|
||||
subreddit=input_data.subreddit,
|
||||
to_username=input_data.to_username,
|
||||
subject=input_data.subject,
|
||||
body=input_data.body,
|
||||
)
|
||||
yield "conversation_id", conversation_id
|
||||
yield "success", True
|
||||
@@ -1,210 +0,0 @@
|
||||
# Stripe Link CLI Block — Auth Exploration
|
||||
|
||||
## What is Stripe Link CLI?
|
||||
|
||||
[`@stripe/link-cli`](https://github.com/stripe/link-cli) lets AI agents get secure,
|
||||
one-time-use payment credentials from a user's **Link wallet** (Stripe's consumer
|
||||
payment product). The core operations are:
|
||||
|
||||
| Operation | Description |
|
||||
|-----------|-------------|
|
||||
| `auth login` | Authenticate the agent with a Link account |
|
||||
| `payment-methods list` | List cards/bank accounts in the wallet |
|
||||
| `spend-request create` | Request a one-time virtual card credential |
|
||||
| `spend-request retrieve` | Get card details once user approves |
|
||||
| `mpp pay` | Execute payment via Machine Payments Protocol (SPT) |
|
||||
|
||||
## Link CLI's Auth Model
|
||||
|
||||
Link CLI uses **OAuth 2.0 Device Code Grant** ([RFC 8628](https://tools.ietf.org/html/rfc8628)):
|
||||
|
||||
```
|
||||
┌─────────┐ ┌──────────────────┐ ┌──────────┐
|
||||
│ AutoGPT │ │ login.link.com │ │ User's │
|
||||
│ Backend │ │ (Auth Server) │ │ Browser │
|
||||
└────┬────┘ └────────┬─────────┘ └────┬─────┘
|
||||
│ │ │
|
||||
│ POST /device/code │ │
|
||||
│ client_id=lwlpk_U7Qy7ThG69STZk │ │
|
||||
│ scope=userinfo:read │ │
|
||||
│ payment_methods.agentic │ │
|
||||
│─────────────────────────────────────>│ │
|
||||
│ │ │
|
||||
│ { device_code, user_code, │ │
|
||||
│ verification_uri, │ │
|
||||
│ verification_uri_complete } │ │
|
||||
│<─────────────────────────────────────│ │
|
||||
│ │ │
|
||||
│ Show verification_uri to user ──────┼──────────────────────────────────>│
|
||||
│ │ │
|
||||
│ │ User visits URL, logs in, │
|
||||
│ │ enters user_code phrase │
|
||||
│ │<─────────────────────────────────│
|
||||
│ │ │
|
||||
│ POST /device/token (poll) │ │
|
||||
│ grant_type=device_code │ │
|
||||
│ device_code=... │ │
|
||||
│─────────────────────────────────────>│ │
|
||||
│ │ │
|
||||
│ { access_token, refresh_token, │ │
|
||||
│ expires_in, token_type } │ │
|
||||
│<─────────────────────────────────────│ │
|
||||
│ │ │
|
||||
```
|
||||
|
||||
**Key details:**
|
||||
- **Client ID**: Hardcoded public `lwlpk_U7Qy7ThG69STZk` (no client_secret needed)
|
||||
- **Scopes**: `userinfo:read payment_methods.agentic`
|
||||
- **Token refresh**: Standard `refresh_token` grant at same `/device/token` endpoint
|
||||
- **API calls**: Bearer token in `Authorization` header to `api.link.com`
|
||||
- **User approval**: Push notification or email via Link app, with code-phrase confirmation
|
||||
|
||||
## AutoGPT's Current OAuth Model
|
||||
|
||||
AutoGPT uses the **Authorization Code Grant** flow:
|
||||
|
||||
1. Backend generates `login_url` → frontend redirects user to provider
|
||||
2. User authorizes → provider redirects back with `code`
|
||||
3. Backend exchanges `code` for tokens via `POST /{provider}/callback`
|
||||
4. Tokens stored as `OAuth2Credentials` (access_token + refresh_token)
|
||||
|
||||
The `BaseOAuthHandler` interface:
|
||||
```python
|
||||
class BaseOAuthHandler(ABC):
|
||||
def __init__(self, client_id, client_secret, redirect_uri): ...
|
||||
def get_login_url(self, scopes, state, code_challenge) -> str: ...
|
||||
async def exchange_code_for_tokens(self, code, scopes, code_verifier) -> OAuth2Credentials: ...
|
||||
async def _refresh_tokens(self, credentials) -> OAuth2Credentials: ...
|
||||
async def revoke_tokens(self, credentials) -> bool: ...
|
||||
```
|
||||
|
||||
## The Mismatch
|
||||
|
||||
| Aspect | Authorization Code (current) | Device Code (Link CLI) |
|
||||
|--------|------------------------------|----------------------|
|
||||
| **Initiation** | Redirect user to login URL | Show verification URL + code phrase |
|
||||
| **Token acquisition** | One-shot callback with `code` | Poll `/device/token` until approved |
|
||||
| **Client secret** | Required | Not used (public client) |
|
||||
| **Redirect URI** | Required | Not used |
|
||||
| **User interaction** | Same browser (redirect) | Separate device (phone/other browser) |
|
||||
|
||||
## Implementation Options
|
||||
|
||||
### Option A: Adapt `BaseOAuthHandler` (Minimal Backend Changes)
|
||||
|
||||
Map the device code flow onto the existing handler interface:
|
||||
|
||||
- `get_login_url()` → call `/device/code`, return `verification_uri_complete`
|
||||
- Store `device_code` in the OAuth state token
|
||||
- `exchange_code_for_tokens()` → poll `/device/token` with the stored `device_code`
|
||||
- The `code` parameter is repurposed as the device_code
|
||||
- `_refresh_tokens()` → standard refresh_token grant ✅
|
||||
- `revoke_tokens()` → call `/device/revoke` ✅
|
||||
|
||||
**Pros:** Minimal backend changes, reuses existing OAuth infrastructure
|
||||
**Cons:**
|
||||
- Frontend redirect UX doesn't match — need to show "visit URL" instead of redirect
|
||||
- Polling doesn't fit the one-shot callback model — `exchange_code_for_tokens` would
|
||||
need to block/poll (potentially for minutes)
|
||||
- The frontend currently opens a popup/redirect; it would need a new "device auth" UI mode
|
||||
|
||||
### Option B: API Key Credential (Simplest)
|
||||
|
||||
Let users paste a pre-obtained access token as an API key:
|
||||
|
||||
```python
|
||||
StripeLinkCredentials = APIKeyCredentials
|
||||
StripeLinkCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.STRIPE_LINK], Literal["api_key"]
|
||||
]
|
||||
```
|
||||
|
||||
**Pros:** Zero infrastructure changes, works today
|
||||
**Cons:**
|
||||
- Terrible UX — user must run `link-cli auth login` externally, copy token
|
||||
- No auto-refresh — tokens expire, user must re-authenticate manually
|
||||
- Defeats the purpose of integrated credential management
|
||||
|
||||
### Option C: New Device Code OAuth Flow (Recommended)
|
||||
|
||||
Add a **device code flow variant** to the integrations system:
|
||||
|
||||
1. **New backend endpoint**: `POST /integrations/{provider}/device-auth`
|
||||
- Calls Link's `/device/code`
|
||||
- Returns `{ verification_url, user_code, poll_token }` to frontend
|
||||
2. **New backend endpoint**: `GET /integrations/{provider}/device-auth/poll`
|
||||
- Polls Link's `/device/token` on demand
|
||||
- Returns `{ status: "pending" | "approved" | "denied" }`
|
||||
- On "approved", stores `OAuth2Credentials` and returns credential metadata
|
||||
3. **Frontend UI**: Show verification URL + code phrase, poll status
|
||||
4. **OAuth handler**: New `BaseDeviceAuthHandler` base class
|
||||
|
||||
```python
|
||||
class BaseDeviceAuthHandler(ABC):
|
||||
"""Handler for OAuth 2.0 Device Code Grant flows"""
|
||||
PROVIDER_NAME: ClassVar[ProviderName | str]
|
||||
|
||||
@abstractmethod
|
||||
async def initiate_device_auth(self, scopes: list[str]) -> DeviceAuthState: ...
|
||||
|
||||
@abstractmethod
|
||||
async def poll_device_auth(self, device_code: str) -> OAuth2Credentials | None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials: ...
|
||||
|
||||
@abstractmethod
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool: ...
|
||||
```
|
||||
|
||||
**Pros:**
|
||||
- Clean separation of concerns
|
||||
- Reusable for other device-code providers (smart TVs, CLI tools, IoT)
|
||||
- Good UX — user sees clear instructions, approval via Link app
|
||||
- Auto-refresh works via standard OAuth2Credentials
|
||||
|
||||
**Cons:**
|
||||
- Requires new API endpoints and frontend UI components
|
||||
- More implementation effort upfront
|
||||
|
||||
### Option D: Backend Polling with SSE/WebSocket (Best UX, Most Complex)
|
||||
|
||||
Like Option C but the backend polls automatically and pushes status to frontend via SSE:
|
||||
|
||||
- Backend initiates device auth and starts polling in a background task
|
||||
- Frontend connects via SSE or WebSocket for real-time status updates
|
||||
- When approved, credentials are stored and frontend is notified instantly
|
||||
|
||||
**Pros:** Best UX (no frontend polling), instant notification
|
||||
**Cons:** Significant complexity, SSE/WebSocket infrastructure needed
|
||||
|
||||
## Recommendation
|
||||
|
||||
**Start with Option C** (New Device Code OAuth Flow). It's the cleanest architecture
|
||||
that properly handles the device code flow without hacking it into the authorization
|
||||
code flow. The endpoints and handler abstraction are also reusable for future providers.
|
||||
|
||||
**Fallback to Option A** if we want a quick proof of concept — the main challenge is
|
||||
frontend UX, but the backend adaptation is relatively straightforward.
|
||||
|
||||
## What Blocks Would Look Like
|
||||
|
||||
Regardless of auth approach, the block surface area would be:
|
||||
|
||||
| Block | Description | Auth Scope |
|
||||
|-------|-------------|------------|
|
||||
| `StripeLinkListPaymentMethodsBlock` | List cards/bank accounts | `payment_methods.agentic` |
|
||||
| `StripeLinkCreateSpendRequestBlock` | Create a spend request | `payment_methods.agentic` |
|
||||
| `StripeLinkRetrieveSpendRequestBlock` | Get spend request + card details | `payment_methods.agentic` |
|
||||
| `StripeLinkRequestApprovalBlock` | Request user approval for spend | `payment_methods.agentic` |
|
||||
| `StripeLinkMPPPayBlock` | Execute MPP payment | `payment_methods.agentic` |
|
||||
|
||||
All blocks would use the same credential type:
|
||||
```python
|
||||
StripeLinkCredentials = OAuth2Credentials
|
||||
StripeLinkCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.STRIPE_LINK], Literal["oauth2"]
|
||||
]
|
||||
```
|
||||
|
||||
See `_auth.py` and `spend_request.py` in this directory for skeleton implementations.
|
||||
@@ -1,63 +0,0 @@
|
||||
"""
|
||||
Stripe Link CLI — Credential definitions for AutoGPT blocks.
|
||||
|
||||
Link CLI uses OAuth 2.0 Device Code Grant (RFC 8628), which produces standard
|
||||
access_token + refresh_token pairs stored as OAuth2Credentials. The device-code
|
||||
acquisition flow is handled by ``StripeLinkDeviceAuthHandler`` in
|
||||
``backend/integrations/oauth/stripe_link.py``.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
LINK_API_BASE_URL = "https://api.link.com"
|
||||
LINK_DEFAULT_SCOPES = ["userinfo:read", "payment_methods.agentic"]
|
||||
|
||||
StripeLinkCredentials = OAuth2Credentials
|
||||
|
||||
StripeLinkCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.STRIPE_LINK], # type: ignore[index]
|
||||
Literal["oauth2"],
|
||||
]
|
||||
|
||||
|
||||
def StripeLinkCredentialsField() -> StripeLinkCredentialsInput:
|
||||
"""
|
||||
Creates a Stripe Link credentials input on a block.
|
||||
|
||||
All Link blocks require the same `payment_methods.agentic` scope.
|
||||
"""
|
||||
return CredentialsField(
|
||||
required_scopes=set(LINK_DEFAULT_SCOPES),
|
||||
description=(
|
||||
"Connect your Stripe Link account to enable the agent to request "
|
||||
"secure, one-time-use payment credentials from your Link wallet. "
|
||||
"You'll approve each spend request via the Link app."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test credentials for block testing
|
||||
# ---------------------------------------------------------------------------
|
||||
TEST_CREDENTIALS = OAuth2Credentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="stripe_link",
|
||||
access_token=SecretStr("mock-link-access-token"),
|
||||
refresh_token=SecretStr("mock-link-refresh-token"),
|
||||
access_token_expires_at=None,
|
||||
scopes=LINK_DEFAULT_SCOPES,
|
||||
title="Mock Stripe Link credentials",
|
||||
username="test@example.com",
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
@@ -1,370 +0,0 @@
|
||||
"""
|
||||
Stripe Link — Spend Request blocks.
|
||||
|
||||
These blocks interact with the Link API (api.link.com) to create, retrieve,
|
||||
and approve spend requests. A spend request provisions a one-time-use virtual
|
||||
card or shared payment token from the user's Link wallet.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.stripe_link._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
StripeLinkCredentials,
|
||||
StripeLinkCredentialsField,
|
||||
StripeLinkCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockOutput, BlockSchemaInput
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
LINK_API_BASE = "https://api.link.com"
|
||||
|
||||
|
||||
async def _link_api_request(
|
||||
credentials: StripeLinkCredentials,
|
||||
method: str,
|
||||
path: str,
|
||||
body: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Make an authenticated request to the Link API.
|
||||
|
||||
Uses the access_token from OAuth2Credentials as a Bearer token.
|
||||
In a real implementation, this should handle 401 → token refresh.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=f"{LINK_API_BASE}{path}",
|
||||
headers=headers,
|
||||
json=body,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block: List Payment Methods
|
||||
# ---------------------------------------------------------------------------
|
||||
class StripeLinkListPaymentMethodsBlock(Block):
|
||||
"""List payment methods (cards and bank accounts) from the user's Link wallet."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: StripeLinkCredentialsInput = StripeLinkCredentialsField()
|
||||
|
||||
class Output(BlockSchemaInput):
|
||||
payment_methods: list[dict[str, Any]] = SchemaField(
|
||||
description="List of payment methods in the Link wallet"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6eacc954-2218-4dc7-a485-5bf21549ecbe",
|
||||
description="List payment methods from a Stripe Link wallet",
|
||||
categories=set(),
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"payment_methods",
|
||||
[
|
||||
{
|
||||
"id": "csmrpd_test",
|
||||
"type": "card",
|
||||
"is_default": True,
|
||||
"card_details": {
|
||||
"brand": "visa",
|
||||
"last4": "4242",
|
||||
"exp_month": 12,
|
||||
"exp_year": 2030,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"_link_api_request": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": "csmrpd_test",
|
||||
"type": "card",
|
||||
"is_default": True,
|
||||
"card_details": {
|
||||
"brand": "visa",
|
||||
"last4": "4242",
|
||||
"exp_month": 12,
|
||||
"exp_year": 2030,
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: StripeLinkCredentials,
|
||||
**kwargs: Any,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
methods = await _link_api_request(credentials, "GET", "/payment_methods")
|
||||
yield "payment_methods", methods
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block: Create Spend Request
|
||||
# ---------------------------------------------------------------------------
|
||||
class StripeLinkCreateSpendRequestBlock(Block):
|
||||
"""
|
||||
Create a spend request to get a one-time-use payment credential.
|
||||
|
||||
The user must approve the request via the Link app before card details
|
||||
are available. Use StripeLinkRetrieveSpendRequestBlock to check status
|
||||
and get the credential once approved.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: StripeLinkCredentialsInput = StripeLinkCredentialsField()
|
||||
payment_method_id: str = SchemaField(
|
||||
description="ID of the payment method to use (from list payment methods)"
|
||||
)
|
||||
merchant_name: str = SchemaField(
|
||||
description="Name of the merchant for this purchase"
|
||||
)
|
||||
merchant_url: str = SchemaField(description="URL of the merchant website")
|
||||
context: str = SchemaField(
|
||||
description=(
|
||||
"Description of the purchase context (min 100 characters). "
|
||||
"Shown to the user when they approve the request."
|
||||
)
|
||||
)
|
||||
amount: int = SchemaField(
|
||||
description="Amount in cents (max 50000)", ge=1, le=50000
|
||||
)
|
||||
currency: str = SchemaField(
|
||||
description="3-letter ISO currency code", default="usd"
|
||||
)
|
||||
request_approval: bool = SchemaField(
|
||||
description=(
|
||||
"If true, immediately sends a push notification to the user "
|
||||
"for approval. Otherwise, call request-approval separately."
|
||||
),
|
||||
default=True,
|
||||
)
|
||||
test_mode: bool = SchemaField(
|
||||
description="Use test mode (fake card 4242424242424242)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaInput):
|
||||
spend_request_id: str = SchemaField(
|
||||
description="ID of the created spend request"
|
||||
)
|
||||
status: str = SchemaField(
|
||||
description="Status: created, pending_approval, approved, denied, etc."
|
||||
)
|
||||
approval_url: str = SchemaField(
|
||||
description="URL the user can visit to approve (if not using push)",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="932c3c12-1e80-4392-8fb3-37824eb8a427",
|
||||
description="Create a Stripe Link spend request for a one-time payment credential",
|
||||
categories=set(),
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"payment_method_id": "csmrpd_test",
|
||||
"merchant_name": "Test Store",
|
||||
"merchant_url": "https://example.com",
|
||||
"context": "x" * 100,
|
||||
"amount": 1000,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("spend_request_id", "lsrq_test123"),
|
||||
("status", "pending_approval"),
|
||||
],
|
||||
test_mock={
|
||||
"_link_api_request": lambda *args, **kwargs: {
|
||||
"id": "lsrq_test123",
|
||||
"status": "pending_approval",
|
||||
"approval_url": "",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: StripeLinkCredentials,
|
||||
**kwargs: Any,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await _link_api_request(
|
||||
credentials,
|
||||
"POST",
|
||||
"/spend_requests",
|
||||
body={
|
||||
"payment_details": input_data.payment_method_id,
|
||||
"merchant_name": input_data.merchant_name,
|
||||
"merchant_url": input_data.merchant_url,
|
||||
"context": input_data.context,
|
||||
"amount": input_data.amount,
|
||||
"currency": input_data.currency,
|
||||
"request_approval": input_data.request_approval,
|
||||
"test": input_data.test_mode,
|
||||
},
|
||||
)
|
||||
yield "spend_request_id", result["id"]
|
||||
yield "status", result["status"]
|
||||
if result.get("approval_url"):
|
||||
yield "approval_url", result["approval_url"]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block: Retrieve Spend Request
|
||||
# ---------------------------------------------------------------------------
|
||||
class StripeLinkRetrieveSpendRequestBlock(Block):
|
||||
"""
|
||||
Retrieve a spend request and its credentials (once approved).
|
||||
|
||||
After the user approves a spend request, this block returns the
|
||||
virtual card details (number, CVC, expiry, billing address) that
|
||||
can be used for a one-time purchase.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: StripeLinkCredentialsInput = StripeLinkCredentialsField()
|
||||
spend_request_id: str = SchemaField(
|
||||
description="ID of the spend request to retrieve (e.g., lsrq_...)"
|
||||
)
|
||||
include_card: bool = SchemaField(
|
||||
description="Include unmasked card details in the response",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaInput):
|
||||
status: str = SchemaField(description="Current status of the spend request")
|
||||
card_number: str = SchemaField(
|
||||
description="Virtual card number (only if approved and include_card=True)",
|
||||
default="",
|
||||
)
|
||||
card_cvc: str = SchemaField(
|
||||
description="Virtual card CVC",
|
||||
default="",
|
||||
)
|
||||
card_exp_month: int = SchemaField(
|
||||
description="Card expiry month",
|
||||
default=0,
|
||||
)
|
||||
card_exp_year: int = SchemaField(
|
||||
description="Card expiry year",
|
||||
default=0,
|
||||
)
|
||||
card_brand: str = SchemaField(
|
||||
description="Card brand (visa, mastercard, etc.)",
|
||||
default="",
|
||||
)
|
||||
valid_until: str = SchemaField(
|
||||
description="ISO timestamp when the virtual card expires",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1aff59ef-e8a2-413e-9410-4ce7e4849337",
|
||||
description="Retrieve a Stripe Link spend request and card credentials",
|
||||
categories=set(),
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"spend_request_id": "lsrq_test123",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("status", "approved"),
|
||||
("card_number", "4242424242424242"),
|
||||
],
|
||||
test_mock={
|
||||
"_link_api_request": lambda *args, **kwargs: {
|
||||
"status": "approved",
|
||||
"card": {
|
||||
"number": "4242424242424242",
|
||||
"cvc": "123",
|
||||
"exp_month": 12,
|
||||
"exp_year": 2030,
|
||||
"brand": "visa",
|
||||
"valid_until": "2025-12-31T23:59:59Z",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: StripeLinkCredentials,
|
||||
**kwargs: Any,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
include = ["card"] if input_data.include_card else []
|
||||
path = f"/spend_requests/{input_data.spend_request_id}"
|
||||
if include:
|
||||
path += f"?include={','.join(include)}"
|
||||
|
||||
result = await _link_api_request(credentials, "GET", path)
|
||||
|
||||
yield "status", result["status"]
|
||||
|
||||
card = result.get("card")
|
||||
if card:
|
||||
yield "card_number", card.get("number", "")
|
||||
yield "card_cvc", card.get("cvc", "")
|
||||
yield "card_exp_month", card.get("exp_month", 0)
|
||||
yield "card_exp_year", card.get("exp_year", 0)
|
||||
yield "card_brand", card.get("brand", "")
|
||||
yield "valid_until", card.get("valid_until", "")
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -0,0 +1,140 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.reddit import TEST_CREDENTIALS
|
||||
from backend.blocks.reddit_moderation import (
|
||||
ApproveRedditPostBlock,
|
||||
BanSubredditUserBlock,
|
||||
LockRedditPostBlock,
|
||||
ModQueueBlock,
|
||||
RemoveRedditPostBlock,
|
||||
SendModMailBlock,
|
||||
UnbanSubredditUserBlock,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("block_cls", "expected_scopes"),
|
||||
[
|
||||
(ModQueueBlock, ["modposts"]),
|
||||
(RemoveRedditPostBlock, ["modposts"]),
|
||||
(ApproveRedditPostBlock, ["modposts"]),
|
||||
(LockRedditPostBlock, ["modposts"]),
|
||||
(BanSubredditUserBlock, ["modcontributors"]),
|
||||
(UnbanSubredditUserBlock, ["modcontributors"]),
|
||||
(SendModMailBlock, ["modmail"]),
|
||||
],
|
||||
)
|
||||
def test_moderation_blocks_require_expected_scopes(block_cls, expected_scopes):
|
||||
block = block_cls()
|
||||
|
||||
field = block.input_schema.model_fields["credentials"]
|
||||
scopes = sorted((field.json_schema_extra or {}).get("credentials_scopes") or [])
|
||||
|
||||
assert scopes == expected_scopes
|
||||
|
||||
|
||||
def test_get_mod_queue_uses_modqueue_and_submission_fullnames(mocker):
|
||||
queued_item = SimpleNamespace(
|
||||
id="abc123",
|
||||
fullname="t3_abc123",
|
||||
title="Queued title",
|
||||
author="queued-user",
|
||||
permalink="/r/test/comments/abc123/queued_title/",
|
||||
mod_reason_title="",
|
||||
)
|
||||
sub = MagicMock()
|
||||
sub.mod.modqueue.return_value = [queued_item]
|
||||
client = MagicMock()
|
||||
client.subreddit.return_value = sub
|
||||
mocker.patch("backend.blocks.reddit_moderation.get_praw", return_value=client)
|
||||
|
||||
items = ModQueueBlock.get_mod_queue(
|
||||
TEST_CREDENTIALS,
|
||||
subreddit="test",
|
||||
limit=5,
|
||||
only="submissions",
|
||||
)
|
||||
|
||||
sub.mod.modqueue.assert_called_once_with(limit=5, only="submissions")
|
||||
assert items == [
|
||||
{
|
||||
"id": "t3_abc123",
|
||||
"type": "submission",
|
||||
"title": "Queued title",
|
||||
"author": "queued-user",
|
||||
"permalink": "/r/test/comments/abc123/queued_title/",
|
||||
"reason": "",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_get_mod_queue_preserves_comment_fullnames(mocker):
|
||||
queued_item = SimpleNamespace(
|
||||
id="xyz789",
|
||||
fullname="t1_xyz789",
|
||||
author=None,
|
||||
permalink="/r/test/comments/abc123/comment/",
|
||||
mod_reason_title=None,
|
||||
)
|
||||
sub = MagicMock()
|
||||
sub.mod.modqueue.return_value = [queued_item]
|
||||
client = MagicMock()
|
||||
client.subreddit.return_value = sub
|
||||
mocker.patch("backend.blocks.reddit_moderation.get_praw", return_value=client)
|
||||
|
||||
items = ModQueueBlock.get_mod_queue(
|
||||
TEST_CREDENTIALS,
|
||||
subreddit="test",
|
||||
limit=5,
|
||||
only="comments",
|
||||
)
|
||||
|
||||
assert items[0]["id"] == "t1_xyz789"
|
||||
assert items[0]["type"] == "comment"
|
||||
assert items[0]["title"] == "[comment]"
|
||||
assert items[0]["author"] == "[deleted]"
|
||||
|
||||
|
||||
def test_remove_post_accepts_comment_fullname_and_truncates_mod_note(mocker):
|
||||
moderated_comment = MagicMock()
|
||||
moderated_comment.mod = MagicMock()
|
||||
client = MagicMock()
|
||||
client.comment.return_value = moderated_comment
|
||||
mocker.patch("backend.blocks.reddit_moderation.get_praw", return_value=client)
|
||||
|
||||
result = RemoveRedditPostBlock.remove_post(
|
||||
TEST_CREDENTIALS,
|
||||
post_id="t1_xyz789",
|
||||
spam=False,
|
||||
mod_note="x" * 300,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
client.comment.assert_called_once_with(id="xyz789")
|
||||
moderated_comment.mod.remove.assert_called_once_with(
|
||||
spam=False,
|
||||
mod_note="x" * 250,
|
||||
)
|
||||
|
||||
|
||||
def test_ban_user_rejects_non_positive_duration(mocker):
|
||||
client = MagicMock()
|
||||
subreddit = MagicMock()
|
||||
client.subreddit.return_value = subreddit
|
||||
mocker.patch("backend.blocks.reddit_moderation.get_praw", return_value=client)
|
||||
|
||||
with pytest.raises(ValueError, match="positive number of days"):
|
||||
BanSubredditUserBlock.ban_user(
|
||||
TEST_CREDENTIALS,
|
||||
subreddit="testsubreddit",
|
||||
username="spamuser123",
|
||||
duration=0,
|
||||
reason="Spam",
|
||||
mod_note=None,
|
||||
ban_message=None,
|
||||
)
|
||||
|
||||
subreddit.banned.add.assert_not_called()
|
||||
@@ -353,6 +353,17 @@ class TestFindMatchingOAuth2Credential:
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_matches_credential_with_wildcard_scope(self):
|
||||
"""An OAuth2 credential with wildcard scope should satisfy any scope requirement."""
|
||||
cred = self._make_oauth2_cred("reddit", scopes=["*"])
|
||||
field_info = self._make_field_info(
|
||||
ProviderName.REDDIT,
|
||||
required_scopes=frozenset(["modposts"]),
|
||||
)
|
||||
|
||||
result = find_matching_credential([cred], field_info)
|
||||
assert result is not None
|
||||
|
||||
def test_rejects_credential_with_insufficient_scopes(self):
|
||||
"""An OAuth2 credential missing required scopes should not match."""
|
||||
cred = self._make_oauth2_cred(
|
||||
|
||||
@@ -462,7 +462,10 @@ def _credential_has_required_scopes(
|
||||
# If no scopes are required, any credential matches
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
granted_scopes = set(credential.scopes or [])
|
||||
return "*" in granted_scopes or granted_scopes.issuperset(
|
||||
requirements.required_scopes
|
||||
)
|
||||
|
||||
|
||||
def _credential_is_for_host(
|
||||
|
||||
@@ -446,9 +446,7 @@ Credentials = Annotated[
|
||||
]
|
||||
|
||||
|
||||
CredentialsType = Literal[
|
||||
"api_key", "oauth2", "user_password", "host_scoped", "device_code"
|
||||
]
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"]
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
|
||||
@@ -602,64 +602,6 @@ class IntegrationCredentialsStore:
|
||||
|
||||
return None
|
||||
|
||||
async def peek_state_token(
|
||||
self, user_id: str, token: str, provider: str
|
||||
) -> Optional[OAuthState]:
|
||||
"""Validate a state token WITHOUT consuming it.
|
||||
|
||||
Used by the device-auth polling loop: the state must survive many
|
||||
poll attempts and is only consumed once auth reaches a terminal
|
||||
state (approved / denied / expired).
|
||||
"""
|
||||
async with await self.locked_user_integrations(user_id):
|
||||
user_integrations = await self._get_user_integrations(user_id)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
return next(
|
||||
(
|
||||
state
|
||||
for state in user_integrations.oauth_states
|
||||
if secrets.compare_digest(state.token, token)
|
||||
and provider_matches(state.provider, provider)
|
||||
and state.expires_at > now.timestamp()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
async def consume_state_token(
|
||||
self, user_id: str, token: str, provider: str
|
||||
) -> Optional[OAuthState]:
|
||||
"""Validate and remove a state token (one-time consumption).
|
||||
|
||||
Used when the device-auth flow reaches a terminal state so the
|
||||
token cannot be reused.
|
||||
"""
|
||||
async with await self.locked_user_integrations(user_id):
|
||||
user_integrations = await self._get_user_integrations(user_id)
|
||||
oauth_states = user_integrations.oauth_states
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_state = next(
|
||||
(
|
||||
state
|
||||
for state in oauth_states
|
||||
if secrets.compare_digest(state.token, token)
|
||||
and provider_matches(state.provider, provider)
|
||||
and state.expires_at > now.timestamp()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if valid_state:
|
||||
oauth_states.remove(valid_state)
|
||||
user_integrations.oauth_states = oauth_states
|
||||
await self.db_manager.update_user_integrations(
|
||||
user_id, user_integrations
|
||||
)
|
||||
return valid_state
|
||||
|
||||
return None
|
||||
|
||||
# =================== GET/SET HELPERS =================== #
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -12,18 +12,13 @@ from backend.integrations.credentials_store import (
|
||||
IntegrationCredentialsStore,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.oauth import (
|
||||
CREDENTIALS_BY_PROVIDER,
|
||||
DEVICE_HANDLERS_BY_NAME,
|
||||
HANDLERS_BY_NAME,
|
||||
)
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.oauth.device_base import BaseDeviceAuthHandler
|
||||
from backend.integrations.oauth import BaseOAuthHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -203,21 +198,10 @@ class IntegrationCredentialsManager:
|
||||
|
||||
async def _get_oauth_handler(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> "BaseOAuthHandler | BaseDeviceAuthHandler":
|
||||
) -> "BaseOAuthHandler":
|
||||
"""Resolve the appropriate OAuth handler for the given credentials."""
|
||||
if provider_matches(credentials.provider, ProviderName.MCP.value):
|
||||
return create_mcp_oauth_handler(credentials)
|
||||
|
||||
# Try device handlers first (they don't need client_id/secret lookup)
|
||||
provider_key = (
|
||||
credentials.provider.value
|
||||
if hasattr(credentials.provider, "value")
|
||||
else str(credentials.provider)
|
||||
)
|
||||
if provider_key in DEVICE_HANDLERS_BY_NAME:
|
||||
handler_class = DEVICE_HANDLERS_BY_NAME[provider_key]
|
||||
return handler_class()
|
||||
|
||||
return await _get_provider_oauth_handler(credentials.provider)
|
||||
|
||||
async def _refresh_locked(
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.integrations.oauth.todoist import TodoistOAuthHandler
|
||||
|
||||
from .device_base import BaseDeviceAuthHandler
|
||||
from .discord import DiscordOAuthHandler
|
||||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
from .reddit import RedditOAuthHandler
|
||||
from .stripe_link import StripeLinkDeviceAuthHandler
|
||||
from .twitter import TwitterOAuthHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -229,64 +227,4 @@ HANDLERS_BY_NAME: dict[str, type["BaseOAuthHandler"]] = SDKAwareHandlersDict()
|
||||
CREDENTIALS_BY_PROVIDER: dict[str, SDKAwareCredentials] = SDKAwareCredentialsDict()
|
||||
# --8<-- [end:HANDLERS_BY_NAMEExample]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Device Code Grant handlers (RFC 8628)
|
||||
# ------------------------------------------------------------------ #
|
||||
_ORIGINAL_DEVICE_HANDLERS: list[type[BaseDeviceAuthHandler]] = [
|
||||
StripeLinkDeviceAuthHandler,
|
||||
]
|
||||
|
||||
_device_handlers_dict: dict[str, type[BaseDeviceAuthHandler]] = {
|
||||
(
|
||||
handler.PROVIDER_NAME.value
|
||||
if hasattr(handler.PROVIDER_NAME, "value")
|
||||
else str(handler.PROVIDER_NAME)
|
||||
): handler
|
||||
for handler in _ORIGINAL_DEVICE_HANDLERS
|
||||
}
|
||||
|
||||
|
||||
class DeviceHandlersDict(dict):
|
||||
"""Dictionary for device-code auth handlers."""
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in _device_handlers_dict:
|
||||
return _device_handlers_dict[key]
|
||||
raise KeyError(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return _device_handlers_dict.get(key, default)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in _device_handlers_dict
|
||||
|
||||
def keys(self):
|
||||
return _device_handlers_dict.keys()
|
||||
|
||||
def values(self):
|
||||
return _device_handlers_dict.values()
|
||||
|
||||
def items(self):
|
||||
return _device_handlers_dict.items()
|
||||
|
||||
|
||||
DEVICE_HANDLERS_BY_NAME: dict[str, type[BaseDeviceAuthHandler]] = DeviceHandlersDict()
|
||||
|
||||
# Unified lookup: any handler type for a given provider
|
||||
AnyAuthHandler = Union[type["BaseOAuthHandler"], type[BaseDeviceAuthHandler]]
|
||||
|
||||
|
||||
def get_any_handler(provider_key: str) -> AnyAuthHandler | None:
|
||||
"""Resolve an auth handler (auth-code or device-code) for a provider."""
|
||||
if provider_key in HANDLERS_BY_NAME:
|
||||
return HANDLERS_BY_NAME[provider_key]
|
||||
if provider_key in DEVICE_HANDLERS_BY_NAME:
|
||||
return DEVICE_HANDLERS_BY_NAME[provider_key]
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HANDLERS_BY_NAME",
|
||||
"DEVICE_HANDLERS_BY_NAME",
|
||||
"get_any_handler",
|
||||
]
|
||||
__all__ = ["HANDLERS_BY_NAME"]
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
"""
|
||||
Base handler for OAuth 2.0 Device Code Grant (RFC 8628).
|
||||
|
||||
Providers that use the device authorization flow (CLI tools, IoT devices,
|
||||
smart TVs, etc.) implement this handler instead of ``BaseOAuthHandler``.
|
||||
|
||||
The resulting credentials are standard ``OAuth2Credentials`` — the device
|
||||
code flow is only an *acquisition method*, not a different credential shape.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeviceAuthInitiation(BaseModel):
|
||||
"""Returned when initiating the device code flow."""
|
||||
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_url: str
|
||||
verification_url_complete: Optional[str] = None
|
||||
expires_in: int
|
||||
interval: int # recommended seconds between polls
|
||||
|
||||
|
||||
class DeviceAuthPollResult(BaseModel):
|
||||
"""Returned from each poll attempt during the device code flow."""
|
||||
|
||||
status: Literal["pending", "slow_down", "approved", "denied", "expired"]
|
||||
credentials: Optional[OAuth2Credentials] = None
|
||||
next_poll_interval: Optional[int] = None
|
||||
|
||||
|
||||
class BaseDeviceAuthHandler(ABC):
|
||||
"""
|
||||
Abstract handler for OAuth 2.0 Device Code Grant flows.
|
||||
|
||||
Subclasses implement provider-specific HTTP calls; the token lifecycle
|
||||
helpers (refresh, needs_refresh, get_access_token) mirror
|
||||
``BaseOAuthHandler`` so the credential manager can dispatch refresh
|
||||
calls uniformly.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME: ClassVar[ProviderName | str]
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
||||
|
||||
@abstractmethod
|
||||
async def initiate_device_auth(self, scopes: list[str]) -> DeviceAuthInitiation:
|
||||
"""Start the device code flow. Returns URLs/codes for the user."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def poll_for_tokens(self, device_code: str) -> DeviceAuthPollResult:
|
||||
"""
|
||||
Poll the auth server for token completion.
|
||||
|
||||
Returns a result with status ``"pending"`` or ``"slow_down"`` while
|
||||
waiting, ``"approved"`` with credentials on success, or
|
||||
``"denied"``/``"expired"`` on terminal failure.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
"""Implements the token refresh mechanism."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
"""Revokes the given token at the provider.
|
||||
Returns False if the provider does not support revocation."""
|
||||
...
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Non-abstract helpers — same interface as BaseOAuthHandler
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
|
||||
if credentials.provider != self.PROVIDER_NAME:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} cannot refresh tokens "
|
||||
f"for provider '{credentials.provider}'"
|
||||
)
|
||||
return await self._refresh_tokens(credentials)
|
||||
|
||||
async def get_access_token(self, credentials: OAuth2Credentials) -> str:
|
||||
"""Returns a valid access token, refreshing it first if needed."""
|
||||
if self.needs_refresh(credentials):
|
||||
credentials = await self.refresh_tokens(credentials)
|
||||
return credentials.access_token.get_secret_value()
|
||||
|
||||
def needs_refresh(self, credentials: OAuth2Credentials) -> bool:
|
||||
"""Indicates whether the given tokens need to be refreshed."""
|
||||
return (
|
||||
credentials.access_token_expires_at is not None
|
||||
and credentials.access_token_expires_at < int(time.time()) + 300
|
||||
)
|
||||
|
||||
def handle_default_scopes(self, scopes: list[str]) -> list[str]:
|
||||
"""Uses default scopes when none are provided."""
|
||||
if not scopes:
|
||||
logger.debug(f"Using default scopes for provider {str(self.PROVIDER_NAME)}")
|
||||
scopes = self.DEFAULT_SCOPES
|
||||
return scopes
|
||||
@@ -65,6 +65,23 @@ class RedditOAuthHandler(BaseOAuthHandler):
|
||||
|
||||
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _get_granted_scopes(
|
||||
tokens: dict, fallback_scopes: list[str] | None = None
|
||||
) -> list[str]:
|
||||
raw_scopes = tokens.get("scope", "")
|
||||
if isinstance(raw_scopes, str):
|
||||
granted_scopes = raw_scopes.split()
|
||||
elif isinstance(raw_scopes, list):
|
||||
granted_scopes = [scope for scope in raw_scopes if isinstance(scope, str)]
|
||||
else:
|
||||
granted_scopes = []
|
||||
|
||||
if not granted_scopes or "*" in granted_scopes:
|
||||
return list(fallback_scopes or [])
|
||||
|
||||
return granted_scopes
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
@@ -101,6 +118,7 @@ class RedditOAuthHandler(BaseOAuthHandler):
|
||||
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
|
||||
|
||||
username = await self._get_username(tokens["access_token"])
|
||||
granted_scopes = self._get_granted_scopes(tokens, scopes)
|
||||
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
@@ -110,7 +128,7 @@ class RedditOAuthHandler(BaseOAuthHandler):
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
||||
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
|
||||
scopes=scopes,
|
||||
scopes=granted_scopes,
|
||||
)
|
||||
|
||||
async def _get_username(self, access_token: str) -> str:
|
||||
@@ -164,6 +182,8 @@ class RedditOAuthHandler(BaseOAuthHandler):
|
||||
|
||||
username = await self._get_username(tokens["access_token"])
|
||||
|
||||
granted_scopes = self._get_granted_scopes(tokens, credentials.scopes)
|
||||
|
||||
# Reddit may or may not return a new refresh token
|
||||
new_refresh_token = tokens.get("refresh_token")
|
||||
if new_refresh_token:
|
||||
@@ -183,7 +203,7 @@ class RedditOAuthHandler(BaseOAuthHandler):
|
||||
refresh_token=refresh_token,
|
||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
||||
refresh_token_expires_at=None,
|
||||
scopes=credentials.scopes,
|
||||
scopes=granted_scopes,
|
||||
)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.oauth.reddit import RedditOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
def _handler() -> RedditOAuthHandler:
|
||||
return RedditOAuthHandler(
|
||||
client_id="test-client-id",
|
||||
client_secret="test-client-secret",
|
||||
redirect_uri="https://example.com/callback",
|
||||
)
|
||||
|
||||
|
||||
def _creds(scopes: list[str] | None = None) -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
provider=ProviderName.REDDIT,
|
||||
title=None,
|
||||
username="reddit-user",
|
||||
access_token=SecretStr("access-token-value"),
|
||||
refresh_token=SecretStr("refresh-token-value"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=scopes or ["identity", "read", "modposts"],
|
||||
)
|
||||
|
||||
|
||||
def test_get_login_url_uses_least_privilege_default_scopes():
|
||||
url = _handler().get_login_url([], "state-token", None)
|
||||
|
||||
query = parse_qs(urlparse(url).query)
|
||||
scopes = set(query["scope"][0].split())
|
||||
|
||||
assert scopes == {
|
||||
"identity",
|
||||
"read",
|
||||
"submit",
|
||||
"edit",
|
||||
"history",
|
||||
"privatemessages",
|
||||
"flair",
|
||||
}
|
||||
assert scopes.isdisjoint({"modposts", "modcontributors", "modmail", "modlog"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_code_for_tokens_uses_granted_scopes(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "access-token-value",
|
||||
"refresh_token": "refresh-token-value",
|
||||
"expires_in": 3600,
|
||||
"scope": "identity read",
|
||||
}
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mocker.patch(
|
||||
"backend.integrations.oauth.reddit.Requests",
|
||||
return_value=MagicMock(post=mock_post),
|
||||
)
|
||||
|
||||
handler = _handler()
|
||||
mocker.patch.object(
|
||||
handler,
|
||||
"_get_username",
|
||||
AsyncMock(return_value="reddit-user"),
|
||||
)
|
||||
|
||||
creds = await handler.exchange_code_for_tokens(
|
||||
code="auth-code",
|
||||
scopes=["identity", "read", "modposts"],
|
||||
code_verifier=None,
|
||||
)
|
||||
|
||||
assert creds.scopes == ["identity", "read"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_code_for_tokens_uses_requested_scopes_for_wildcard(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
requested_scopes = ["identity", "read", "modposts"]
|
||||
mock_response = MagicMock()
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "access-token-value",
|
||||
"refresh_token": "refresh-token-value",
|
||||
"expires_in": 3600,
|
||||
"scope": "*",
|
||||
}
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mocker.patch(
|
||||
"backend.integrations.oauth.reddit.Requests",
|
||||
return_value=MagicMock(post=mock_post),
|
||||
)
|
||||
|
||||
handler = _handler()
|
||||
mocker.patch.object(
|
||||
handler,
|
||||
"_get_username",
|
||||
AsyncMock(return_value="reddit-user"),
|
||||
)
|
||||
|
||||
creds = await handler.exchange_code_for_tokens(
|
||||
code="auth-code",
|
||||
scopes=requested_scopes,
|
||||
code_verifier=None,
|
||||
)
|
||||
|
||||
assert creds.scopes == requested_scopes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_uses_returned_scope_string(mocker: MockerFixture):
|
||||
mock_response = MagicMock()
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "identity read",
|
||||
}
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mocker.patch(
|
||||
"backend.integrations.oauth.reddit.Requests",
|
||||
return_value=MagicMock(post=mock_post),
|
||||
)
|
||||
|
||||
handler = _handler()
|
||||
mocker.patch.object(
|
||||
handler,
|
||||
"_get_username",
|
||||
AsyncMock(return_value="reddit-user"),
|
||||
)
|
||||
|
||||
refreshed = await handler._refresh_tokens(_creds())
|
||||
|
||||
assert refreshed.scopes == ["identity", "read"]
|
||||
assert refreshed.refresh_token == SecretStr("refresh-token-value")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_uses_existing_scopes_for_wildcard(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
existing_scopes = ["identity", "read", "modposts"]
|
||||
mock_response = MagicMock()
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "*",
|
||||
}
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mocker.patch(
|
||||
"backend.integrations.oauth.reddit.Requests",
|
||||
return_value=MagicMock(post=mock_post),
|
||||
)
|
||||
|
||||
handler = _handler()
|
||||
mocker.patch.object(
|
||||
handler,
|
||||
"_get_username",
|
||||
AsyncMock(return_value="reddit-user"),
|
||||
)
|
||||
|
||||
refreshed = await handler._refresh_tokens(_creds(existing_scopes))
|
||||
|
||||
assert refreshed.scopes == existing_scopes
|
||||
assert refreshed.refresh_token == SecretStr("refresh-token-value")
|
||||
@@ -1,153 +0,0 @@
|
||||
"""
|
||||
Stripe Link — OAuth 2.0 Device Code Grant handler.
|
||||
|
||||
Implements the device code flow for Stripe Link (login.link.com).
|
||||
Uses a public client ID (no client_secret).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import ClassVar
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.oauth.device_base import (
|
||||
BaseDeviceAuthHandler,
|
||||
DeviceAuthInitiation,
|
||||
DeviceAuthPollResult,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LINK_AUTH_BASE_URL = "https://login.link.com"
|
||||
LINK_CLIENT_ID = "lwlpk_U7Qy7ThG69STZk"
|
||||
|
||||
|
||||
class StripeLinkDeviceAuthHandler(BaseDeviceAuthHandler):
|
||||
"""Device code handler for Stripe Link."""
|
||||
|
||||
PROVIDER_NAME: ClassVar[ProviderName] = ProviderName.STRIPE_LINK
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = [
|
||||
"userinfo:read",
|
||||
"payment_methods.agentic",
|
||||
]
|
||||
|
||||
async def initiate_device_auth(self, scopes: list[str]) -> DeviceAuthInitiation:
|
||||
import socket
|
||||
|
||||
effective_scopes = self.handle_default_scopes(scopes)
|
||||
hostname = socket.gethostname()
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{LINK_AUTH_BASE_URL}/device/code",
|
||||
data={
|
||||
"client_id": LINK_CLIENT_ID,
|
||||
"scope": " ".join(effective_scopes),
|
||||
"connection_label": f"AutoGPT on {hostname}",
|
||||
"client_hint": "AutoGPT",
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return DeviceAuthInitiation(
|
||||
device_code=data["device_code"],
|
||||
user_code=data["user_code"],
|
||||
verification_url=data["verification_uri"],
|
||||
verification_url_complete=data.get("verification_uri_complete"),
|
||||
expires_in=data["expires_in"],
|
||||
interval=data.get("interval", 5),
|
||||
)
|
||||
|
||||
async def poll_for_tokens(self, device_code: str) -> DeviceAuthPollResult:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{LINK_AUTH_BASE_URL}/device/token",
|
||||
data={
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": device_code,
|
||||
"client_id": LINK_CLIENT_ID,
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
credentials = OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
access_token=SecretStr(data["access_token"]),
|
||||
refresh_token=SecretStr(data["refresh_token"]),
|
||||
access_token_expires_at=int(time.time()) + data["expires_in"],
|
||||
scopes=self.DEFAULT_SCOPES,
|
||||
title="Stripe Link",
|
||||
)
|
||||
return DeviceAuthPollResult(status="approved", credentials=credentials)
|
||||
|
||||
if response.status_code == 400:
|
||||
error = response.json()
|
||||
error_code = error.get("error", "")
|
||||
|
||||
if error_code == "authorization_pending":
|
||||
return DeviceAuthPollResult(status="pending")
|
||||
|
||||
if error_code == "slow_down":
|
||||
return DeviceAuthPollResult(
|
||||
status="slow_down",
|
||||
next_poll_interval=10,
|
||||
)
|
||||
|
||||
if error_code == "expired_token":
|
||||
return DeviceAuthPollResult(status="expired")
|
||||
|
||||
if error_code == "access_denied":
|
||||
return DeviceAuthPollResult(status="denied")
|
||||
|
||||
raise RuntimeError(
|
||||
f"Unexpected response from Link auth: "
|
||||
f"{response.status_code} {response.text}"
|
||||
)
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
if not credentials.refresh_token:
|
||||
raise RuntimeError("No refresh token available")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{LINK_AUTH_BASE_URL}/device/token",
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
"client_id": LINK_CLIENT_ID,
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
credentials.access_token = SecretStr(data["access_token"])
|
||||
credentials.refresh_token = SecretStr(data["refresh_token"])
|
||||
credentials.access_token_expires_at = int(time.time()) + data["expires_in"]
|
||||
return credentials
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
if not credentials.refresh_token:
|
||||
return False
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{LINK_AUTH_BASE_URL}/device/revoke",
|
||||
data={
|
||||
"client_id": LINK_CLIENT_ID,
|
||||
"token": credentials.refresh_token.get_secret_value(),
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
return response.status_code == 200
|
||||
@@ -48,7 +48,6 @@ class ProviderName(str, Enum):
|
||||
SLANT3D = "slant3d"
|
||||
SMARTLEAD = "smartlead"
|
||||
SMTP = "smtp"
|
||||
STRIPE_LINK = "stripe_link"
|
||||
TELEGRAM = "telegram"
|
||||
TWITTER = "twitter"
|
||||
TODOIST = "todoist"
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import React from "react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import type { CredentialsMetaResponse } from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
CredentialsProvidersContext,
|
||||
type CredentialsProvidersContextType,
|
||||
} from "@/providers/agent-credentials/credentials-provider";
|
||||
import { useAgentRunModal } from "./useAgentRunModal";
|
||||
|
||||
const executeGraphMutate = vi.fn();
|
||||
const setupTriggerMutate = vi.fn();
|
||||
const invalidateQueries = vi.fn();
|
||||
const toast = vi.fn();
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/graphs/graphs", () => ({
|
||||
getGetV1ListGraphExecutionsQueryKey: vi.fn(() => ["graph-executions"]),
|
||||
usePostV1ExecuteGraphAgent: vi.fn(() => ({
|
||||
mutate: executeGraphMutate,
|
||||
isPending: false,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/presets/presets", () => ({
|
||||
getGetV2ListPresetsQueryKey: vi.fn(() => ["presets"]),
|
||||
usePostV2SetupTrigger: vi.fn(() => ({
|
||||
mutate: setupTriggerMutate,
|
||||
isPending: false,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("@tanstack/react-query", () => ({
|
||||
useQueryClient: vi.fn(() => ({
|
||||
invalidateQueries,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
useToast: vi.fn(() => ({
|
||||
toast,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("@/services/analytics", () => ({
|
||||
analytics: {
|
||||
sendDatafastEvent: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("./errorHelpers", () => ({
|
||||
showExecutionErrorToast: vi.fn(),
|
||||
}));
|
||||
|
||||
function makeCredential(
|
||||
partial: Partial<CredentialsMetaResponse>,
|
||||
): CredentialsMetaResponse {
|
||||
return {
|
||||
id: "cred-id",
|
||||
provider: "reddit",
|
||||
type: "oauth2",
|
||||
title: "Reddit credential",
|
||||
scopes: [],
|
||||
...partial,
|
||||
} as CredentialsMetaResponse;
|
||||
}
|
||||
|
||||
function makeProviders(
|
||||
savedCredentials: CredentialsMetaResponse[],
|
||||
): CredentialsProvidersContextType {
|
||||
return {
|
||||
reddit: {
|
||||
provider: "reddit",
|
||||
providerName: "Reddit",
|
||||
savedCredentials,
|
||||
isSystemProvider: true,
|
||||
oAuthCallback: vi.fn(),
|
||||
mcpOAuthCallback: vi.fn(),
|
||||
createAPIKeyCredentials: vi.fn(),
|
||||
createUserPasswordCredentials: vi.fn(),
|
||||
createHostScopedCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function makeAgent(): LibraryAgent {
|
||||
return {
|
||||
id: "agent-id",
|
||||
graph_id: "agent-graph-id",
|
||||
graph_version: 1,
|
||||
name: "Reddit moderation agent",
|
||||
input_schema: {
|
||||
properties: {},
|
||||
required: [],
|
||||
},
|
||||
credentials_input_schema: {
|
||||
properties: {
|
||||
reddit_credentials: {
|
||||
credentials_provider: ["reddit"],
|
||||
credentials_types: ["oauth2"],
|
||||
credentials_scopes: ["modposts"],
|
||||
},
|
||||
},
|
||||
required: ["reddit_credentials"],
|
||||
},
|
||||
trigger_setup_info: null,
|
||||
} as unknown as LibraryAgent;
|
||||
}
|
||||
|
||||
describe("useAgentRunModal", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("initializes wildcard-scoped system credentials as matching", async () => {
|
||||
const providers = makeProviders([
|
||||
makeCredential({
|
||||
id: "wild",
|
||||
title: "Wildcard Reddit credential",
|
||||
scopes: ["*"],
|
||||
is_system: true,
|
||||
}),
|
||||
]);
|
||||
|
||||
function Wrapper({ children }: { children: React.ReactNode }) {
|
||||
return React.createElement(
|
||||
CredentialsProvidersContext.Provider,
|
||||
{ value: providers },
|
||||
children,
|
||||
);
|
||||
}
|
||||
|
||||
const { result } = renderHook(() => useAgentRunModal(makeAgent()), {
|
||||
wrapper: Wrapper,
|
||||
});
|
||||
|
||||
await waitFor(() =>
|
||||
expect(result.current.inputCredentials.reddit_credentials).toMatchObject({
|
||||
id: "wild",
|
||||
provider: "reddit",
|
||||
type: "oauth2",
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -10,6 +10,7 @@ import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutio
|
||||
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { hasRequiredCredentialScopes } from "@/lib/credentials/hasRequiredCredentialScopes";
|
||||
import { isEmpty } from "@/lib/utils";
|
||||
import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider";
|
||||
import { analytics } from "@/services/analytics";
|
||||
@@ -110,9 +111,9 @@ export function useAgentRunModal(
|
||||
requiredScopes &&
|
||||
requiredScopes.length > 0
|
||||
) {
|
||||
const grantedScopes = new Set(cred.scopes || []);
|
||||
const hasAllRequiredScopes = requiredScopes.every(
|
||||
(scope: string) => grantedScopes.has(scope),
|
||||
const hasAllRequiredScopes = hasRequiredCredentialScopes(
|
||||
cred.scopes,
|
||||
requiredScopes,
|
||||
);
|
||||
if (!hasAllRequiredScopes) return false;
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ interface Props {
|
||||
|
||||
const TAB_PRIORITY: AuthMethod[] = [
|
||||
AuthType.oauth2,
|
||||
AuthType.device_code,
|
||||
AuthType.api_key,
|
||||
AuthType.user_password,
|
||||
AuthType.host_scoped,
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { ArrowSquareOutIcon, SpinnerGapIcon, XIcon } from "@phosphor-icons/react";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
|
||||
import { useDeviceAuthConnect } from "./useDeviceAuthConnect";
|
||||
|
||||
interface Props {
|
||||
provider: string;
|
||||
providerName: string;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
export function DeviceAuthConnectButton({
|
||||
provider,
|
||||
providerName,
|
||||
onSuccess,
|
||||
}: Props) {
|
||||
const { connect, cancel, phase, userCode, verificationUrl, isPending } =
|
||||
useDeviceAuthConnect({ provider, onSuccess });
|
||||
|
||||
if (phase === "idle" || phase === "error" || phase === "done") {
|
||||
return (
|
||||
<div className="flex flex-col gap-3">
|
||||
<Text variant="body" className="text-[#505057]">
|
||||
{providerName} uses device authorization. Click below, then follow the
|
||||
link to approve access.
|
||||
</Text>
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
size="large"
|
||||
onClick={connect}
|
||||
loading={false}
|
||||
>
|
||||
Connect {providerName}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text variant="body" className="text-[#505057]">
|
||||
Open the link below and enter the code to connect your {providerName}{" "}
|
||||
account.
|
||||
</Text>
|
||||
|
||||
<div className="flex flex-col gap-3 rounded-lg border border-[#E0E0E3] bg-[#F9F9FA] p-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<Text variant="small" className="font-medium text-[#83838C]">
|
||||
Your code
|
||||
</Text>
|
||||
<Text
|
||||
variant="h3"
|
||||
as="p"
|
||||
className="select-all text-center font-mono text-2xl tracking-widest text-[#1F1F20]"
|
||||
>
|
||||
{userCode}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<a
|
||||
href={verificationUrl}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="inline-flex items-center justify-center gap-2 rounded-md bg-[#1F1F20] px-4 py-2.5 text-sm font-medium text-white transition-colors hover:bg-[#2F2F30]"
|
||||
>
|
||||
Open {providerName}
|
||||
<ArrowSquareOutIcon size={16} />
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2 text-[#83838C]">
|
||||
<SpinnerGapIcon size={16} className="animate-spin" />
|
||||
<Text variant="small">Waiting for approval…</Text>
|
||||
</div>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={cancel}
|
||||
rightIcon={<XIcon size={14} />}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import {
|
||||
type ConnectableProvider,
|
||||
} from "../../helpers";
|
||||
import { ApiKeyConnectForm } from "./ApiKeyConnectForm";
|
||||
import { DeviceAuthConnectButton } from "./DeviceAuthConnectButton";
|
||||
import { OAuthConnectButton } from "./OAuthConnectButton";
|
||||
import { UnsupportedNotice } from "./UnsupportedNotice";
|
||||
|
||||
@@ -15,7 +14,6 @@ const TAB_LABEL: Record<AuthMethod, string> = {
|
||||
[AuthType.api_key]: "API key",
|
||||
[AuthType.user_password]: "User / password",
|
||||
[AuthType.host_scoped]: "Host",
|
||||
[AuthType.device_code]: "Device auth",
|
||||
};
|
||||
|
||||
interface Props {
|
||||
@@ -43,15 +41,6 @@ export function MethodPanel({ method, provider, onSuccess }: Props) {
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (method === AuthType.device_code) {
|
||||
return (
|
||||
<DeviceAuthConnectButton
|
||||
provider={provider.id}
|
||||
providerName={provider.name}
|
||||
onSuccess={onSuccess}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<UnsupportedNotice
|
||||
providerName={provider.name}
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
|
||||
import { getGetV1ListCredentialsQueryKey } from "@/app/api/__generated__/endpoints/integrations/integrations";
|
||||
import { customMutator } from "@/app/api/mutators/custom-mutator";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
|
||||
interface DeviceAuthInitiateResponse {
|
||||
state_token: string;
|
||||
device_code: string;
|
||||
user_code: string;
|
||||
verification_url: string;
|
||||
verification_url_complete: string | null;
|
||||
expires_in: number;
|
||||
interval: number;
|
||||
}
|
||||
|
||||
interface DeviceAuthPollResponse {
|
||||
status: "pending" | "slow_down" | "approved" | "denied" | "expired";
|
||||
credentials: unknown | null;
|
||||
}
|
||||
|
||||
interface Args {
|
||||
provider: string;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
type Phase = "idle" | "awaiting_user" | "polling" | "done" | "error";
|
||||
|
||||
export function useDeviceAuthConnect({ provider, onSuccess }: Args) {
|
||||
const queryClient = useQueryClient();
|
||||
const [phase, setPhase] = useState<Phase>("idle");
|
||||
const [userCode, setUserCode] = useState("");
|
||||
const [verificationUrl, setVerificationUrl] = useState("");
|
||||
const [stateToken, setStateToken] = useState("");
|
||||
|
||||
const isUnmountedRef = useRef(false);
|
||||
const pollingRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const intervalRef = useRef(5);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
isUnmountedRef.current = true;
|
||||
if (pollingRef.current) clearTimeout(pollingRef.current);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const stopPolling = useCallback(() => {
|
||||
if (pollingRef.current) {
|
||||
clearTimeout(pollingRef.current);
|
||||
pollingRef.current = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const poll = useCallback(
|
||||
async (token: string) => {
|
||||
if (isUnmountedRef.current) return;
|
||||
|
||||
try {
|
||||
const response = await customMutator<{
|
||||
data: DeviceAuthPollResponse;
|
||||
status: number;
|
||||
headers: Headers;
|
||||
}>(`/integrations/${provider}/device-auth/poll`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ state_token: token }),
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
|
||||
if (isUnmountedRef.current) return;
|
||||
|
||||
const { status } = response.data;
|
||||
|
||||
if (status === "approved") {
|
||||
setPhase("done");
|
||||
stopPolling();
|
||||
toast({ title: "Connected via device auth", variant: "success" });
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV1ListCredentialsQueryKey(),
|
||||
});
|
||||
onSuccess();
|
||||
return;
|
||||
}
|
||||
|
||||
if (status === "slow_down") {
|
||||
intervalRef.current = Math.min(intervalRef.current + 5, 30);
|
||||
}
|
||||
|
||||
if (status === "denied" || status === "expired") {
|
||||
setPhase("error");
|
||||
stopPolling();
|
||||
toast({
|
||||
title:
|
||||
status === "denied"
|
||||
? "Authorization denied"
|
||||
: "Authorization expired",
|
||||
description:
|
||||
status === "denied"
|
||||
? "The authorization request was denied."
|
||||
: "The authorization request expired. Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// pending or slow_down — schedule next poll
|
||||
pollingRef.current = setTimeout(
|
||||
() => poll(token),
|
||||
intervalRef.current * 1000,
|
||||
);
|
||||
} catch (error) {
|
||||
if (isUnmountedRef.current) return;
|
||||
setPhase("error");
|
||||
stopPolling();
|
||||
toast({
|
||||
title: "Device auth polling failed",
|
||||
description:
|
||||
error instanceof Error ? error.message : "Unexpected error",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
},
|
||||
[provider, onSuccess, queryClient, stopPolling],
|
||||
);
|
||||
|
||||
async function connect() {
|
||||
setPhase("awaiting_user");
|
||||
try {
|
||||
const response = await customMutator<{
|
||||
data: DeviceAuthInitiateResponse;
|
||||
status: number;
|
||||
headers: Headers;
|
||||
}>(`/integrations/${provider}/device-auth/initiate`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
|
||||
if (isUnmountedRef.current) return;
|
||||
|
||||
const data = response.data;
|
||||
setUserCode(data.user_code);
|
||||
setVerificationUrl(
|
||||
data.verification_url_complete || data.verification_url,
|
||||
);
|
||||
setStateToken(data.state_token);
|
||||
intervalRef.current = data.interval;
|
||||
|
||||
// Start polling
|
||||
setPhase("polling");
|
||||
pollingRef.current = setTimeout(
|
||||
() => poll(data.state_token),
|
||||
data.interval * 1000,
|
||||
);
|
||||
} catch (error) {
|
||||
if (isUnmountedRef.current) return;
|
||||
setPhase("error");
|
||||
toast({
|
||||
title: "Device auth initiation failed",
|
||||
description:
|
||||
error instanceof Error ? error.message : "Unexpected error",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function cancel() {
|
||||
stopPolling();
|
||||
setPhase("idle");
|
||||
setUserCode("");
|
||||
setVerificationUrl("");
|
||||
setStateToken("");
|
||||
}
|
||||
|
||||
return {
|
||||
connect,
|
||||
cancel,
|
||||
phase,
|
||||
userCode,
|
||||
verificationUrl,
|
||||
isPending: phase === "awaiting_user" || phase === "polling",
|
||||
};
|
||||
}
|
||||
@@ -217,4 +217,50 @@ describe("useCredentialsInput – handleScopeUpgrade", () => {
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it("accepts wildcard scopes returned from OAuth callback", async () => {
|
||||
const oAuthLoginMock = vi.fn().mockResolvedValue({
|
||||
login_url: "https://accounts.google.com/o/oauth2/auth",
|
||||
state_token: "state-wild",
|
||||
});
|
||||
mockUseBackendAPI.mockReturnValue({ oAuthLogin: oAuthLoginMock });
|
||||
|
||||
const oAuthCallback = vi.fn().mockResolvedValue(
|
||||
makeCred({
|
||||
id: "wildcard-cred",
|
||||
scopes: ["*"],
|
||||
}),
|
||||
);
|
||||
|
||||
mockUseCredentials.mockReturnValue(
|
||||
makeCredentialsReturn({ oAuthCallback }),
|
||||
);
|
||||
|
||||
mockOpenOAuthPopup.mockReturnValue({
|
||||
promise: Promise.resolve({ code: "wild-code", state: "state-wild" }),
|
||||
cleanup: { abort: vi.fn() },
|
||||
});
|
||||
|
||||
const onSelect = vi.fn();
|
||||
const { result } = renderHook(() =>
|
||||
useCredentialsInput({
|
||||
schema: baseSchema,
|
||||
onSelectCredential: onSelect,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result.current.isLoading).toBe(false);
|
||||
if (result.current.isLoading) return;
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleOAuthLogin!();
|
||||
});
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith({
|
||||
id: "wildcard-cred",
|
||||
type: "oauth2",
|
||||
title: "Test",
|
||||
provider: "google",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import type { CredentialsMetaResponse } from "@/lib/autogpt-server-api";
|
||||
import type { CredentialsProvidersContextType } from "@/providers/agent-credentials/credentials-provider";
|
||||
import { findSavedCredentialByProviderAndType } from "./helpers";
|
||||
|
||||
function makeCredential(
|
||||
partial: Partial<CredentialsMetaResponse>,
|
||||
): CredentialsMetaResponse {
|
||||
return {
|
||||
id: "cred-id",
|
||||
provider: "reddit",
|
||||
type: "oauth2",
|
||||
title: "Reddit credential",
|
||||
scopes: [],
|
||||
...partial,
|
||||
} as CredentialsMetaResponse;
|
||||
}
|
||||
|
||||
function makeProviders(
|
||||
savedCredentials: CredentialsMetaResponse[],
|
||||
): CredentialsProvidersContextType {
|
||||
return {
|
||||
reddit: {
|
||||
provider: "reddit",
|
||||
providerName: "Reddit",
|
||||
savedCredentials,
|
||||
isSystemProvider: true,
|
||||
oAuthCallback: vi.fn(),
|
||||
mcpOAuthCallback: vi.fn(),
|
||||
createAPIKeyCredentials: vi.fn(),
|
||||
createUserPasswordCredentials: vi.fn(),
|
||||
createHostScopedCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe("findSavedCredentialByProviderAndType", () => {
|
||||
it("accepts wildcard-scoped oauth credentials for required scopes", () => {
|
||||
const providers = makeProviders([
|
||||
makeCredential({
|
||||
id: "wild",
|
||||
scopes: ["*"],
|
||||
is_system: true,
|
||||
}),
|
||||
]);
|
||||
|
||||
const credential = findSavedCredentialByProviderAndType(
|
||||
["reddit"],
|
||||
["oauth2"],
|
||||
["modposts"],
|
||||
providers,
|
||||
);
|
||||
|
||||
expect(credential?.id).toBe("wild");
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,5 @@
|
||||
import { CredentialsProvidersContextType } from "@/providers/agent-credentials/credentials-provider";
|
||||
import { hasRequiredCredentialScopes } from "@/lib/credentials/hasRequiredCredentialScopes";
|
||||
import { filterSystemCredentials, getSystemCredentials } from "../../helpers";
|
||||
|
||||
export type CredentialField = [string, any];
|
||||
@@ -15,12 +16,7 @@ function hasRequiredScopes(
|
||||
requiredScopes?: string[],
|
||||
) {
|
||||
if (credential.type !== "oauth2") return true;
|
||||
if (!requiredScopes || requiredScopes.length === 0) return true;
|
||||
const grantedScopes = new Set(credential.scopes || []);
|
||||
for (const scope of requiredScopes) {
|
||||
if (!grantedScopes.has(scope)) return false;
|
||||
}
|
||||
return true;
|
||||
return hasRequiredCredentialScopes(credential.scopes, requiredScopes);
|
||||
}
|
||||
|
||||
/** Check if a credential matches the discriminator values (e.g. MCP server URL). */
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { hasRequiredCredentialScopes } from "@/lib/credentials/hasRequiredCredentialScopes";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
@@ -219,9 +220,9 @@ export function useCredentialsInput({
|
||||
if (!isMCP) {
|
||||
const requiredScopes = schema.credentials_scopes;
|
||||
if (requiredScopes && requiredScopes.length > 0) {
|
||||
const grantedScopes = new Set(credentialResult.scopes || []);
|
||||
const hasAllRequiredScopes = new Set(requiredScopes).isSubsetOf(
|
||||
grantedScopes,
|
||||
const hasAllRequiredScopes = hasRequiredCredentialScopes(
|
||||
credentialResult.scopes,
|
||||
requiredScopes,
|
||||
);
|
||||
|
||||
if (!hasAllRequiredScopes) {
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { CredentialsMetaResponse } from "@/app/api/__generated__/models/credentialsMetaResponse";
|
||||
import type { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api";
|
||||
import { filterCredentialsByProvider } from "./helpers";
|
||||
|
||||
function makeCredential(
|
||||
partial: Partial<CredentialsMetaResponse>,
|
||||
): CredentialsMetaResponse {
|
||||
return {
|
||||
id: "cred-id",
|
||||
provider: "reddit",
|
||||
type: "oauth2",
|
||||
title: "Reddit credential",
|
||||
scopes: [],
|
||||
...partial,
|
||||
} as CredentialsMetaResponse;
|
||||
}
|
||||
|
||||
function makeSchema(
|
||||
partial: Partial<BlockIOCredentialsSubSchema> = {},
|
||||
): BlockIOCredentialsSubSchema {
|
||||
return {
|
||||
credentials_provider: ["reddit"],
|
||||
credentials_types: ["oauth2"],
|
||||
credentials_scopes: ["modposts"],
|
||||
...partial,
|
||||
} as BlockIOCredentialsSubSchema;
|
||||
}
|
||||
|
||||
describe("filterCredentialsByProvider", () => {
|
||||
it("keeps wildcard-scoped oauth credentials when scopes are required", () => {
|
||||
const result = filterCredentialsByProvider(
|
||||
[makeCredential({ id: "wild", scopes: ["*"] })],
|
||||
"reddit",
|
||||
makeSchema(),
|
||||
);
|
||||
|
||||
expect(result.exists).toBe(true);
|
||||
expect(result.credentials.map((credential) => credential.id)).toEqual([
|
||||
"wild",
|
||||
]);
|
||||
});
|
||||
|
||||
it("filters oauth credentials that still lack a required scope", () => {
|
||||
const result = filterCredentialsByProvider(
|
||||
[makeCredential({ id: "narrow", scopes: ["read"] })],
|
||||
"reddit",
|
||||
makeSchema(),
|
||||
);
|
||||
|
||||
expect(result.exists).toBe(false);
|
||||
expect(result.credentials).toEqual([]);
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,6 @@
|
||||
import { CredentialsMetaResponse } from "@/app/api/__generated__/models/credentialsMetaResponse";
|
||||
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api";
|
||||
import { hasRequiredCredentialScopes } from "@/lib/credentials/hasRequiredCredentialScopes";
|
||||
import { getHostFromUrl } from "@/lib/utils/url";
|
||||
import {
|
||||
GoogleLogoIcon,
|
||||
@@ -32,10 +33,9 @@ export const filterCredentialsByProvider = (
|
||||
|
||||
// Filter OAuth credentials that have sufficient scopes for this block
|
||||
if (credential.type === "oauth2" && schema?.credentials_scopes) {
|
||||
const credentialScopes = new Set(credential.scopes || []);
|
||||
const requiredScopes = new Set(schema.credentials_scopes);
|
||||
const hasAllScopes = [...requiredScopes].every((scope) =>
|
||||
credentialScopes.has(scope),
|
||||
const hasAllScopes = hasRequiredCredentialScopes(
|
||||
credential.scopes,
|
||||
schema.credentials_scopes,
|
||||
);
|
||||
if (!hasAllScopes) {
|
||||
return false;
|
||||
|
||||
@@ -71,6 +71,20 @@ describe("classifyCredentials", () => {
|
||||
expect(upgradeableCredentials).toEqual([]);
|
||||
});
|
||||
|
||||
it("classifies wildcard-scoped OAuth2 creds as saved", () => {
|
||||
const schema = makeSchema({
|
||||
credentials_scopes: ["modposts", "modcontributors"],
|
||||
});
|
||||
const { savedCredentials, upgradeableCredentials } = classifyCredentials(
|
||||
[makeCred({ id: "wildcard", scopes: ["*"] })],
|
||||
schema,
|
||||
undefined,
|
||||
);
|
||||
|
||||
expect(savedCredentials.map((c) => c.id)).toEqual(["wildcard"]);
|
||||
expect(upgradeableCredentials).toEqual([]);
|
||||
});
|
||||
|
||||
it("classifies OAuth2 creds missing a scope as upgradeable (not discarded)", () => {
|
||||
// Regression coverage for the incremental-OAuth flow: a credential
|
||||
// that's missing only one scope must land in upgradeableCredentials so
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
CredentialsProviderData,
|
||||
CredentialsProvidersContext,
|
||||
} from "@/providers/agent-credentials/credentials-provider";
|
||||
import { hasRequiredCredentialScopes } from "@/lib/credentials/hasRequiredCredentialScopes";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaResponse,
|
||||
@@ -37,12 +38,10 @@ export function classifyCredentials(
|
||||
|
||||
if (c.type === "oauth2") {
|
||||
const requiredScopes = credsInputSchema.credentials_scopes;
|
||||
// Set.prototype.isSupersetOf is ES2025 and this project targets
|
||||
// ES2022 — fall back to an array every() check so the picker's
|
||||
// scope filter runs cleanly on current Node/browser baselines.
|
||||
const credScopes = new Set(c.scopes);
|
||||
const hasAllScopes =
|
||||
!requiredScopes || requiredScopes.every((s) => credScopes.has(s));
|
||||
const hasAllScopes = hasRequiredCredentialScopes(
|
||||
c.scopes,
|
||||
requiredScopes,
|
||||
);
|
||||
if (hasAllScopes) {
|
||||
savedCredentials.push(c);
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { hasRequiredCredentialScopes } from "./hasRequiredCredentialScopes";
|
||||
|
||||
describe("hasRequiredCredentialScopes", () => {
|
||||
it("returns true when no scopes are required", () => {
|
||||
expect(hasRequiredCredentialScopes(["read"], undefined)).toBe(true);
|
||||
expect(hasRequiredCredentialScopes(["read"], [])).toBe(true);
|
||||
});
|
||||
|
||||
it("treats wildcard scopes as satisfying all requirements", () => {
|
||||
expect(
|
||||
hasRequiredCredentialScopes(["*"], ["modposts", "modcontributors"]),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it("returns false when a required scope is missing", () => {
|
||||
expect(hasRequiredCredentialScopes(["read"], ["read", "modposts"])).toBe(
|
||||
false,
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,15 @@
|
||||
export function hasRequiredCredentialScopes(
|
||||
grantedScopes: readonly string[] | null | undefined,
|
||||
requiredScopes: readonly string[] | null | undefined,
|
||||
) {
|
||||
if (!requiredScopes || requiredScopes.length === 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const normalizedGrantedScopes = new Set(grantedScopes ?? []);
|
||||
if (normalizedGrantedScopes.has("*")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return requiredScopes.every((scope) => normalizedGrantedScopes.has(scope));
|
||||
}
|
||||
@@ -302,6 +302,8 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
||||
|
||||
| Block Name | Description |
|
||||
|------------|-------------|
|
||||
| [Approve Reddit Post](block-integrations/misc.md#approve-reddit-post) | Approves a Reddit post or comment from the mod queue |
|
||||
| [Ban Subreddit User](block-integrations/misc.md#ban-subreddit-user) | Bans a user from a subreddit |
|
||||
| [Create Discord Thread](block-integrations/discord/bot_blocks.md#create-discord-thread) | Creates a new thread in a Discord channel |
|
||||
| [Create Reddit Post](block-integrations/misc.md#create-reddit-post) | Create a new post on a subreddit |
|
||||
| [Delete Reddit Comment](block-integrations/misc.md#delete-reddit-comment) | Delete a Reddit comment that you own |
|
||||
@@ -328,6 +330,8 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
||||
| [Get User Posts](block-integrations/misc.md#get-user-posts) | Fetch posts by a specific Reddit user |
|
||||
| [Linkedin Person Lookup](block-integrations/enrichlayer/linkedin.md#linkedin-person-lookup) | Look up LinkedIn profiles by person information using Enrichlayer |
|
||||
| [Linkedin Role Lookup](block-integrations/enrichlayer/linkedin.md#linkedin-role-lookup) | Look up LinkedIn profiles by role in a company using Enrichlayer |
|
||||
| [Lock Reddit Post](block-integrations/misc.md#lock-reddit-post) | Locks or unlocks a Reddit post or comment to prevent or allow replies |
|
||||
| [Mod Queue](block-integrations/misc.md#mod-queue) | Fetches the mod queue for a subreddit |
|
||||
| [Post Reddit Comment](block-integrations/misc.md#post-reddit-comment) | This block posts a Reddit comment on a specified Reddit post |
|
||||
| [Post To Bluesky](block-integrations/ayrshare/post_to_bluesky.md#post-to-bluesky) | Post to Bluesky using Ayrshare |
|
||||
| [Post To Facebook](block-integrations/ayrshare/post_to_facebook.md#post-to-facebook) | Post to Facebook using Ayrshare |
|
||||
@@ -345,6 +349,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
||||
| [Publish To Medium](block-integrations/misc.md#publish-to-medium) | Publishes a post to Medium |
|
||||
| [Read Discord Messages](block-integrations/discord/bot_blocks.md#read-discord-messages) | Reads new messages from a Discord channel using a bot token and triggers when a new message is posted |
|
||||
| [Reddit Get My Posts](block-integrations/misc.md#reddit-get-my-posts) | Fetch posts created by the authenticated Reddit user (you) |
|
||||
| [Remove Reddit Post](block-integrations/misc.md#remove-reddit-post) | Removes a Reddit post or comment as a moderator |
|
||||
| [Reply To Discord Message](block-integrations/discord/bot_blocks.md#reply-to-discord-message) | Replies to a specific Discord message |
|
||||
| [Reply To Reddit Comment](block-integrations/misc.md#reply-to-reddit-comment) | Reply to a specific Reddit comment |
|
||||
| [Reply To Telegram Message](block-integrations/telegram/blocks.md#reply-to-telegram-message) | Reply to a specific message in a Telegram chat |
|
||||
@@ -353,6 +358,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
||||
| [Send Discord Embed](block-integrations/discord/bot_blocks.md#send-discord-embed) | Sends a rich embed message to a Discord channel |
|
||||
| [Send Discord File](block-integrations/discord/bot_blocks.md#send-discord-file) | Sends a file attachment to a Discord channel |
|
||||
| [Send Discord Message](block-integrations/discord/bot_blocks.md#send-discord-message) | Sends a message to a Discord channel using a bot token |
|
||||
| [Send Mod Mail](block-integrations/misc.md#send-mod-mail) | Sends a modmail message from a subreddit to a user |
|
||||
| [Send Reddit Message](block-integrations/misc.md#send-reddit-message) | Send a private message (DM) to a Reddit user |
|
||||
| [Send Telegram Audio](block-integrations/telegram/blocks.md#send-telegram-audio) | Send an audio file to a Telegram chat |
|
||||
| [Send Telegram Document](block-integrations/telegram/blocks.md#send-telegram-document) | Send a document (any file type) to a Telegram chat |
|
||||
@@ -414,6 +420,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
||||
| [Twitter Unmute User](block-integrations/twitter/mutes.md#twitter-unmute-user) | This block unmutes a specified Twitter user |
|
||||
| [Twitter Unpin List](block-integrations/twitter/pinned_lists.md#twitter-unpin-list) | This block allows the authenticated user to unpin a specified List |
|
||||
| [Twitter Update List](block-integrations/twitter/manage_lists.md#twitter-update-list) | This block updates a specified Twitter List owned by the authenticated user |
|
||||
| [Unban Subreddit User](block-integrations/misc.md#unban-subreddit-user) | Unbans a user from a subreddit |
|
||||
|
||||
## Communication
|
||||
|
||||
|
||||
@@ -38,6 +38,43 @@ Input and output schemas define the expected data structure for communication be
|
||||
|
||||
---
|
||||
|
||||
## Approve Reddit Post
|
||||
|
||||
### What it is
|
||||
Approves a Reddit post or comment from the mod queue. Requires 'modposts' scope.
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
This block normalizes the incoming `post_id`, accepting either a bare submission ID or a Reddit fullname such as `t3_...` or `t1_...`. It then loads the correct PRAW object for a submission or comment and calls the moderator `approve()` action with credentials that include the `modposts` scope.
|
||||
|
||||
On success, the block returns the original `post_id` and `success=True`, which makes it easy to chain directly from `Mod Queue` output without losing whether the item started as a comment or a submission.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| post_id | ID or fullname of the post/comment to approve, such as 't3_abc123', 't1_xyz789', or bare submission ID 'abc123' | str | Yes |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if the operation failed | str |
|
||||
| post_id | ID of the approved post (pass-through) | str |
|
||||
| success | Whether the approval succeeded | bool |
|
||||
|
||||
### Possible use case
|
||||
<!-- MANUAL: use_case -->
|
||||
**False Positive Cleanup**: Approve items from the mod queue after an external classifier determines they are safe.
|
||||
|
||||
**Appeal Handling**: Restore a previously filtered post or comment after a moderator reviews a user appeal.
|
||||
|
||||
**Hybrid Moderation**: Combine automated queue scoring with a final approval step for borderline content.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
|
||||
## AutoPilot
|
||||
|
||||
### What it is
|
||||
@@ -86,6 +123,50 @@ Tool and block identifiers provided in `tools` and `blocks` are validated at run
|
||||
|
||||
---
|
||||
|
||||
## Ban Subreddit User
|
||||
|
||||
### What it is
|
||||
Bans a user from a subreddit. Requires 'modcontributors' scope.
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
This block validates the optional `duration` before sending the request, rejects non-positive values, and truncates the internal `reason` and `mod_note` fields to Reddit's moderation limits. It then calls `sub.banned.add(...)`, optionally including a temporary duration and a user-facing `ban_message`.
|
||||
|
||||
The outputs echo the target user and subreddit, plus `success` and a derived `permanent` flag so downstream steps can branch on temporary versus permanent bans. Use `reason` and `mod_note` for internal moderation context, and reserve `ban_message` for the explanation shown to the banned user.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| subreddit | Subreddit to ban the user from, excluding the /r/ prefix | str | Yes |
|
||||
| username | Reddit username to ban (without the u/ prefix) | str | Yes |
|
||||
| duration | Ban duration in days. Leave blank for a permanent ban. | int | No |
|
||||
| reason | Internal moderator-only ban reason (max 100 chars). Use ban_message to explain the ban to the user. | str | No |
|
||||
| mod_note | Internal moderator note (not shown to the user) | str | No |
|
||||
| ban_message | Optional custom message sent to the user explaining the ban | str | No |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if the operation failed | str |
|
||||
| username | Banned username (pass-through) | str |
|
||||
| subreddit | Subreddit (pass-through) | str |
|
||||
| success | Whether the ban was applied | bool |
|
||||
| permanent | True if the ban is permanent | bool |
|
||||
|
||||
### Possible use case
|
||||
<!-- MANUAL: use_case -->
|
||||
**Escalation Workflow**: Ban repeat offenders automatically after multiple confirmed moderation violations.
|
||||
|
||||
**Temporary Cooldown**: Apply short bans during heated incidents while moderators review the situation.
|
||||
|
||||
**Policy Enforcement**: Pair user reporting or detection blocks with standardized ban actions and messaging.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
|
||||
## Create Reddit Post
|
||||
|
||||
### What it is
|
||||
@@ -737,6 +818,88 @@ The sandbox persists until its timeout expires or it's explicitly disposed. Use
|
||||
|
||||
---
|
||||
|
||||
## Lock Reddit Post
|
||||
|
||||
### What it is
|
||||
Locks or unlocks a Reddit post or comment to prevent or allow replies. Requires 'modposts' scope.
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
This block resolves the target `post_id` to the correct Reddit object, so both submissions and comments can be locked or unlocked from the same input field. It then calls `lock()` when `lock=True` or `unlock()` when `lock=False`, using moderator credentials with the `modposts` scope.
|
||||
|
||||
The block returns the original `post_id` plus the resulting `locked` state, which makes it useful in review pipelines where the moderation decision and final state need to be recorded explicitly.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| post_id | ID or fullname of the post/comment to lock or unlock | str | Yes |
|
||||
| lock | True to lock (disable comments/replies), False to unlock | bool | No |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if the operation failed | str |
|
||||
| post_id | ID of the post (pass-through) | str |
|
||||
| locked | Current lock state after the action | bool |
|
||||
|
||||
### Possible use case
|
||||
<!-- MANUAL: use_case -->
|
||||
**Thread Freeze**: Lock a post when discussion becomes abusive or starts attracting brigading.
|
||||
|
||||
**Post-Resolution Cleanup**: Unlock a thread again after moderators have resolved the incident and want to reopen discussion.
|
||||
|
||||
**Incident Playbooks**: Trigger lock or unlock actions automatically from moderation decision trees.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
|
||||
## Mod Queue
|
||||
|
||||
### What it is
|
||||
Fetches the mod queue for a subreddit. Requires moderator access.
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
This block calls `sub.mod.modqueue(...)` for the target subreddit, forwarding the optional `only` filter and `limit` value to Reddit. Each returned PRAW item is normalized into a predictable dictionary with a fullname ID, detected item type, title fallback for comments, author, permalink, and moderator reason.
|
||||
|
||||
The block emits every queue entry individually for fan-out workflows and also emits the full `items` list for batch processing. Because the `post_id` output keeps the Reddit fullname (`t1_...` or `t3_...`), downstream moderation blocks can safely act on comments and submissions without extra type checks.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| subreddit | Subreddit name, excluding the /r/ prefix | str | Yes |
|
||||
| limit | Maximum number of items to fetch from the mod queue | int | No |
|
||||
| only | Filter to only submissions or only comments. Leave blank for both. | "submissions" \| "comments" | No |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if the operation failed | str |
|
||||
| post_id | Full Reddit thing ID of a queued item, such as 't3_abc123' or 't1_xyz789' | str |
|
||||
| item_type | Whether the queued item is a comment or submission | "comment" \| "submission" |
|
||||
| post_title | Title of the queued item | str |
|
||||
| author | Username of the author | str |
|
||||
| permalink | Full Reddit permalink | str |
|
||||
| reason | Mod queue reason (if any) | str |
|
||||
| items | All queued items as a list | List[Dict[str, Any]] |
|
||||
|
||||
### Possible use case
|
||||
<!-- MANUAL: use_case -->
|
||||
**Queue Triage**: Pull the latest subreddit queue and route each item into approve, remove, or lock actions.
|
||||
|
||||
**Moderator Dashboards**: Feed queued items into summaries, alerts, or external review systems for human moderators.
|
||||
|
||||
**Policy Automation**: Filter only comments or only submissions when building specialized moderation pipelines.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
|
||||
## Post Reddit Comment
|
||||
|
||||
### What it is
|
||||
@@ -889,6 +1052,45 @@ This block uses the Reddit API via PRAW to fetch posts you've submitted to Reddi
|
||||
|
||||
---
|
||||
|
||||
## Remove Reddit Post
|
||||
|
||||
### What it is
|
||||
Removes a Reddit post or comment as a moderator. Requires 'modposts' scope.
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
This block accepts either bare submission IDs or full Reddit thing IDs and resolves them to the correct submission or comment object before moderation. It calls `thing.mod.remove(...)`, forwarding the `spam` flag and truncating the optional `mod_note` to Reddit's 250-character moderator-note limit.
|
||||
|
||||
The block returns the original `post_id` and a success flag so workflows can record or branch on the removal decision. Passing through fullnames from `Mod Queue` lets the same flow moderate queued comments and submissions without extra conversion.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| post_id | ID or fullname of the post/comment to remove, such as 't3_abc123', 't1_xyz789', or bare submission ID 'abc123' | str | Yes |
|
||||
| spam | Mark as spam (True) or just remove (False). Spam trains the filter. | bool | No |
|
||||
| mod_note | Optional internal moderator note visible only to mods | str | No |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if the operation failed | str |
|
||||
| post_id | ID of the removed post (pass-through) | str |
|
||||
| success | Whether the removal succeeded | bool |
|
||||
|
||||
### Possible use case
|
||||
<!-- MANUAL: use_case -->
|
||||
**Spam Cleanup**: Remove obvious spam and optionally train Reddit's spam filter by marking it as spam.
|
||||
|
||||
**Comment Moderation**: Use fullnames from `Mod Queue` to remove problematic comments without additional lookup steps.
|
||||
|
||||
**Human-in-the-Loop Review**: Attach an internal moderator note when an automated rule removes borderline content.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
|
||||
## Reply To Reddit Comment
|
||||
|
||||
### What it is
|
||||
@@ -1043,6 +1245,46 @@ The block handles connection, authentication, and message delivery, returning a
|
||||
|
||||
---
|
||||
|
||||
## Send Mod Mail
|
||||
|
||||
### What it is
|
||||
Sends a modmail message from a subreddit to a user. Requires 'modmail' scope.
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
This block opens the target subreddit and creates a modmail conversation via `sub.modmail.create(...)` using the provided recipient, subject, and body. Because it uses the subreddit modmail endpoint, the credential must include the `modmail` scope and moderator access to that community.
|
||||
|
||||
On success, the block returns the new conversation ID and `success=True`, which gives later steps a stable reference for logging or follow-up actions. Any Reddit API failure is surfaced through the standard `error` output.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| subreddit | Subreddit to send modmail from, excluding the /r/ prefix | str | Yes |
|
||||
| to_username | Username to send the modmail to (without u/ prefix) | str | Yes |
|
||||
| subject | Subject line of the modmail message | str | Yes |
|
||||
| body | Body of the modmail message | str | Yes |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if the operation failed | str |
|
||||
| conversation_id | ID of the created modmail conversation | str |
|
||||
| success | Whether the modmail was sent | bool |
|
||||
|
||||
### Possible use case
|
||||
<!-- MANUAL: use_case -->
|
||||
**Appeal Responses**: Send official moderator replies when a user asks why content was removed or locked.
|
||||
|
||||
**Proactive Outreach**: Notify a user about rule issues before escalating to stronger moderation actions.
|
||||
|
||||
**Case Management**: Create modmail threads that can be referenced by later audit or follow-up steps.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
|
||||
## Send Reddit Message
|
||||
|
||||
### What it is
|
||||
@@ -1158,3 +1400,42 @@ The transcript text is returned as a single string, suitable for summarization,
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
|
||||
## Unban Subreddit User
|
||||
|
||||
### What it is
|
||||
Unbans a user from a subreddit. Requires 'modcontributors' scope.
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
This block opens the target subreddit with moderator credentials and calls `sub.banned.remove(username)` to remove the user from the community ban list. It requires the `modcontributors` scope, and on success it returns the `username`, `subreddit`, and `success=True` so the unban can be audited or chained into follow-up actions such as notifications.
|
||||
|
||||
Reddit is responsible for validating the username, subreddit, and moderator permissions. If the username is malformed, the subreddit is missing, or the credential does not have sufficient moderator access, the Reddit API error is surfaced through the block's standard `error` output. In practice, repeated unban attempts are typically safe to treat as idempotent workflow steps: if another moderator already removed the ban, the operation should be logged as a no-op for auditability, while transient API failures or rate limits should be retried with normal backoff before escalating to a human moderator.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| subreddit | Subreddit to unban the user from, excluding the /r/ prefix | str | Yes |
|
||||
| username | Reddit username to unban (without the u/ prefix) | str | Yes |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if the operation failed | str |
|
||||
| username | Unbanned username (pass-through) | str |
|
||||
| subreddit | Subreddit (pass-through) | str |
|
||||
| success | Whether the unban succeeded | bool |
|
||||
|
||||
### Possible use case
|
||||
<!-- MANUAL: use_case -->
|
||||
**Appeal Resolution**: Restore access after a moderator approves a user's ban appeal.
|
||||
|
||||
**Temporary Ban Expiry**: Pair scheduled workflows with unban actions when a manual review confirms the restriction should end.
|
||||
|
||||
**Moderator Remediation**: Correct mistaken bans and immediately hand the result to a notification step.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
|
||||
Reference in New Issue
Block a user