diff --git a/autogpt_platform/backend/backend/blocks/reddit.py b/autogpt_platform/backend/backend/blocks/reddit.py index 1af4a0ac05..93147adb2d 100644 --- a/autogpt_platform/backend/backend/blocks/reddit.py +++ b/autogpt_platform/backend/backend/blocks/reddit.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from typing import Iterator, Literal import praw -from praw.models import MoreComments +from praw.models import Comment, MoreComments, Submission from pydantic import BaseModel, SecretStr from backend.data.block import ( @@ -23,6 +23,13 @@ from backend.integrations.providers import ProviderName from backend.util.mock import MockObject from backend.util.settings import Settings +# Type aliases for Reddit API options +UserPostSort = Literal["new", "hot", "top", "controversial"] +SearchSort = Literal["relevance", "hot", "top", "new", "comments"] +TimeFilter = Literal["all", "day", "hour", "month", "week", "year"] +CommentSort = Literal["best", "top", "new", "controversial", "old", "q&a"] +InboxType = Literal["all", "unread", "messages", "mentions", "comment_replies"] + RedditCredentials = OAuth2Credentials RedditCredentialsInput = CredentialsMetaInput[ Literal[ProviderName.REDDIT], @@ -43,7 +50,15 @@ TEST_CREDENTIALS = OAuth2Credentials( access_token=SecretStr("mock-reddit-access-token"), refresh_token=SecretStr("mock-reddit-refresh-token"), access_token_expires_at=9999999999, - scopes=["identity", "read", "submit", "history"], + scopes=[ + "identity", + "read", + "submit", + "edit", + "history", + "privatemessages", + "flair", + ], title="Mock Reddit credentials", username="mock-reddit-username", ) @@ -57,17 +72,12 @@ TEST_CREDENTIALS_INPUT = { class RedditPost(BaseModel): - id: str + post_id: str subreddit: str title: str body: str -class RedditComment(BaseModel): - post_id: str - comment: str - - settings = Settings() logger = logging.getLogger(__name__) @@ -99,6 +109,7 @@ class GetRedditPostsBlock(Block): subreddit: str = SchemaField( description="Subreddit name, excluding the /r/ prefix", default="writingprompts", + advanced=False, ) credentials: RedditCredentialsInput = RedditCredentialsField() last_minutes: int | None = SchemaField( @@ -139,26 +150,32 @@ class GetRedditPostsBlock(Block): ( "post", RedditPost( - id="id1", subreddit="subreddit", title="title1", body="body1" + post_id="id1", + subreddit="subreddit", + title="title1", + body="body1", ), ), ( "post", RedditPost( - id="id2", subreddit="subreddit", title="title2", body="body2" + post_id="id2", + subreddit="subreddit", + title="title2", + body="body2", ), ), ( "posts", [ RedditPost( - id="id1", + post_id="id1", subreddit="subreddit", title="title1", body="body1", ), RedditPost( - id="id2", + post_id="id2", subreddit="subreddit", title="title2", body="body2", @@ -195,13 +212,14 @@ class GetRedditPostsBlock(Block): ) time_difference = current_time - post_datetime if time_difference.total_seconds() / 60 > input_data.last_minutes: - continue + # Posts are ordered newest-first, so all subsequent posts will also be older + break if input_data.last_post and post.id == input_data.last_post: break reddit_post = RedditPost( - id=post.id, + post_id=post.id, subreddit=input_data.subreddit, title=post.title, body=post.selftext, @@ -224,6 +242,9 @@ class PostRedditCommentBlock(Block): class Output(BlockSchemaOutput): comment_id: str = SchemaField(description="Posted comment ID") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) def __init__(self): super().__init__( @@ -239,10 +260,16 @@ class PostRedditCommentBlock(Block): test_credentials=TEST_CREDENTIALS, test_input={ "credentials": TEST_CREDENTIALS_INPUT, - "data": {"post_id": "id", "comment": "comment"}, + "post_id": "test_post_id", + "comment": "comment", + }, + test_output=[ + ("comment_id", "dummy_comment_id"), + ("post_id", "test_post_id"), + ], + test_mock={ + "reply_post": lambda creds, post_id, comment: "dummy_comment_id" }, - test_output=[("comment_id", "dummy_comment_id")], - test_mock={"reply_post": lambda creds, comment: "dummy_comment_id"}, ) @staticmethod @@ -262,6 +289,7 @@ class PostRedditCommentBlock(Block): post_id=input_data.post_id, comment=input_data.comment, ) + yield "post_id", input_data.post_id class CreateRedditPostBlock(Block): @@ -282,13 +310,20 @@ class CreateRedditPostBlock(Block): default=None, ) flair_id: str | None = SchemaField( - description="Flair ID to apply to the post", + description="Flair template ID to apply to the post (from GetSubredditFlairsBlock)", + default=None, + ) + flair_text: str | None = SchemaField( + description="Custom flair text (only used if the flair template allows editing)", default=None, ) class Output(BlockSchemaOutput): post_id: str = SchemaField(description="ID of the created post") post_url: str = SchemaField(description="URL of the created post") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) def __init__(self): super().__init__( @@ -311,9 +346,10 @@ class CreateRedditPostBlock(Block): test_output=[ ("post_id", "abc123"), ("post_url", "https://reddit.com/r/test/comments/abc123/test_post/"), + ("subreddit", "test"), ], test_mock={ - "create_post": lambda creds, subreddit, title, content, url, flair_id: ( + "create_post": lambda creds, subreddit, title, content, url, flair_id, flair_text: ( "abc123", "https://reddit.com/r/test/comments/abc123/test_post/", ) @@ -328,6 +364,7 @@ class CreateRedditPostBlock(Block): content: str = "", url: str | None = None, flair_id: str | None = None, + flair_text: str | None = None, ) -> tuple[str, str]: """ Create a new post on a subreddit. @@ -338,7 +375,8 @@ class CreateRedditPostBlock(Block): title: Post title content: Post body text (for text posts) url: URL to submit (for link posts, overrides content) - flair_id: Optional flair ID to apply + flair_id: Optional flair template ID to apply + flair_text: Optional custom flair text (for editable flairs) Returns: Tuple of (post_id, post_url) @@ -347,9 +385,13 @@ class CreateRedditPostBlock(Block): sub = client.subreddit(subreddit) if url: - submission = sub.submit(title=title, url=url, flair_id=flair_id) + submission = sub.submit( + title=title, url=url, flair_id=flair_id, flair_text=flair_text + ) else: - submission = sub.submit(title=title, selftext=content, flair_id=flair_id) + submission = sub.submit( + title=title, selftext=content, flair_id=flair_id, flair_text=flair_text + ) return submission.id, f"https://reddit.com{submission.permalink}" @@ -363,9 +405,11 @@ class CreateRedditPostBlock(Block): input_data.content, input_data.url, input_data.flair_id, + input_data.flair_text, ) yield "post_id", post_id yield "post_url", post_url + yield "subreddit", input_data.subreddit class RedditPostDetails(BaseModel): @@ -502,8 +546,8 @@ class GetUserPostsBlock(Block): description="Maximum number of posts to fetch", default=10, ) - sort: str = SchemaField( - description="Sort order: 'new', 'hot', 'top', or 'controversial'", + sort: UserPostSort = SchemaField( + description="Sort order for user posts", default="new", ) @@ -535,23 +579,29 @@ class GetUserPostsBlock(Block): ( "post", RedditPost( - id="id1", subreddit="sub1", title="title1", body="body1" + post_id="id1", subreddit="sub1", title="title1", body="body1" ), ), ( "post", RedditPost( - id="id2", subreddit="sub2", title="title2", body="body2" + post_id="id2", subreddit="sub2", title="title2", body="body2" ), ), ( "posts", [ RedditPost( - id="id1", subreddit="sub1", title="title1", body="body1" + post_id="id1", + subreddit="sub1", + title="title1", + body="body1", ), RedditPost( - id="id2", subreddit="sub2", title="title2", body="body2" + post_id="id2", + subreddit="sub2", + title="title2", + body="body2", ), ], ), @@ -576,8 +626,8 @@ class GetUserPostsBlock(Block): @staticmethod def get_user_posts( - creds: RedditCredentials, username: str, limit: int, sort: str - ) -> list: + creds: RedditCredentials, username: str, limit: int, sort: UserPostSort + ) -> list[Submission]: client = get_praw(creds) redditor = client.redditor(username) @@ -607,7 +657,139 @@ class GetUserPostsBlock(Block): all_posts = [] for submission in submissions: post = RedditPost( - id=submission.id, + post_id=submission.id, + subreddit=submission.subreddit.display_name, + title=submission.title, + body=submission.selftext, + ) + all_posts.append(post) + yield "post", post + yield "posts", all_posts + except Exception as e: + yield "error", str(e) + + +class GetMyPostsBlock(Block): + """Get posts by the authenticated Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_limit: int = SchemaField( + description="Maximum number of posts to fetch", + default=10, + ) + sort: UserPostSort = SchemaField( + description="Sort order for posts", + default="new", + ) + + class Output(BlockSchemaOutput): + post: RedditPost = SchemaField(description="A post by you") + posts: list[RedditPost] = SchemaField(description="All your posts") + error: str = SchemaField( + description="Error message if posts couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="4ab3381b-0c07-4201-89b3-fa2ec264f154", + description="Fetch posts created by the authenticated Reddit user (you).", + categories={BlockCategory.SOCIAL}, + input_schema=GetMyPostsBlock.Input, + output_schema=GetMyPostsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_limit": 2, + }, + test_output=[ + ( + "post", + RedditPost( + post_id="id1", subreddit="sub1", title="title1", body="body1" + ), + ), + ( + "post", + RedditPost( + post_id="id2", subreddit="sub2", title="title2", body="body2" + ), + ), + ( + "posts", + [ + RedditPost( + post_id="id1", + subreddit="sub1", + title="title1", + body="body1", + ), + RedditPost( + post_id="id2", + subreddit="sub2", + title="title2", + body="body2", + ), + ], + ), + ], + test_mock={ + "get_my_posts": lambda creds, limit, sort: [ + MockObject( + id="id1", + subreddit=MockObject(display_name="sub1"), + title="title1", + selftext="body1", + ), + MockObject( + id="id2", + subreddit=MockObject(display_name="sub2"), + title="title2", + selftext="body2", + ), + ] + }, + ) + + @staticmethod + def get_my_posts( + creds: RedditCredentials, limit: int, sort: UserPostSort + ) -> list[Submission]: + client = get_praw(creds) + me = client.user.me() + if not me: + raise ValueError("Could not get authenticated user.") + + if sort == "new": + submissions = me.submissions.new(limit=limit) + elif sort == "hot": + submissions = me.submissions.hot(limit=limit) + elif sort == "top": + submissions = me.submissions.top(limit=limit) + elif sort == "controversial": + submissions = me.submissions.controversial(limit=limit) + else: + submissions = me.submissions.new(limit=limit) + + return list(submissions) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + submissions = self.get_my_posts( + credentials, + input_data.post_limit, + input_data.sort, + ) + all_posts = [] + for submission in submissions: + post = RedditPost( + post_id=submission.id, subreddit=submission.subreddit.display_name, title=submission.title, body=submission.selftext, @@ -645,12 +827,12 @@ class SearchRedditBlock(Block): description="Limit search to a specific subreddit (without /r/ prefix)", default=None, ) - sort: str = SchemaField( - description="Sort order: 'relevance', 'hot', 'top', 'new', or 'comments'", + sort: SearchSort = SchemaField( + description="Sort order for search results", default="relevance", ) - time_filter: str = SchemaField( - description="Time filter: 'all', 'day', 'hour', 'month', 'week', or 'year'", + time_filter: TimeFilter = SchemaField( + description="Time filter for search results", default="all", ) limit: int = SchemaField( @@ -772,10 +954,10 @@ class SearchRedditBlock(Block): creds: RedditCredentials, query: str, subreddit: str | None, - sort: str, - time_filter: str, + sort: SearchSort, + time_filter: TimeFilter, limit: int, - ) -> list: + ) -> list[Submission]: client = get_praw(creds) if subreddit: @@ -834,6 +1016,9 @@ class EditRedditPostBlock(Block): class Output(BlockSchemaOutput): success: bool = SchemaField(description="Whether the edit was successful") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) post_url: str = SchemaField(description="URL of the edited post") error: str = SchemaField(description="Error message if the edit failed") @@ -856,6 +1041,7 @@ class EditRedditPostBlock(Block): }, test_output=[ ("success", True), + ("post_id", "abc123"), ("post_url", "https://reddit.com/r/test/comments/abc123/test_post/"), ], test_mock={ @@ -886,6 +1072,7 @@ class EditRedditPostBlock(Block): credentials, input_data.post_id, input_data.new_content ) yield "success", success + yield "post_id", input_data.post_id yield "post_url", post_url except Exception as e: error_msg = str(e) @@ -922,6 +1109,9 @@ class GetSubredditInfoBlock(Block): class Output(BlockSchemaOutput): info: SubredditInfo = SchemaField(description="Subreddit information") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) error: str = SchemaField( description="Error message if the subreddit couldn't be fetched" ) @@ -957,6 +1147,7 @@ class GetSubredditInfoBlock(Block): url="/r/python/", ), ), + ("subreddit", "python"), ], test_mock={ "get_subreddit_info": lambda creds, subreddit: SubredditInfo( @@ -996,16 +1187,17 @@ class GetSubredditInfoBlock(Block): try: info = self.get_subreddit_info(credentials, input_data.subreddit) yield "info", info + yield "subreddit", input_data.subreddit except Exception as e: yield "error", str(e) -class RedditCommentData(BaseModel): - """Data about a Reddit comment.""" +class RedditComment(BaseModel): + """A Reddit comment.""" - id: str + comment_id: str post_id: str - parent_id: str + parent_comment_id: str | None author: str body: str score: int @@ -1016,7 +1208,7 @@ class RedditCommentData(BaseModel): depth: int -class GetPostCommentsBlock(Block): +class GetRedditPostCommentsBlock(Block): """Get comments on a Reddit post.""" class Input(BlockSchemaInput): @@ -1028,25 +1220,28 @@ class GetPostCommentsBlock(Block): description="Maximum number of top-level comments to fetch (max 100)", default=25, ) - sort: str = SchemaField( - description="Sort order: 'best', 'top', 'new', 'controversial', 'old', 'q&a'", + sort: CommentSort = SchemaField( + description="Sort order for comments", default="best", ) class Output(BlockSchemaOutput): - comment: RedditCommentData = SchemaField(description="A comment on the post") - comments: list[RedditCommentData] = SchemaField( - description="All fetched comments" + comment: RedditComment = SchemaField(description="A comment on the post") + comments: list[RedditComment] = SchemaField(description="All fetched comments") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + error: str = SchemaField( + description="Error message if comments couldn't be fetched" ) - error: str = SchemaField(description="Error message if comments couldn't be fetched") def __init__(self): super().__init__( id="98422b2c-c3b0-4d70-871f-56bd966f46da", description="Get top-level comments on a Reddit post.", categories={BlockCategory.SOCIAL}, - input_schema=GetPostCommentsBlock.Input, - output_schema=GetPostCommentsBlock.Output, + input_schema=GetRedditPostCommentsBlock.Input, + output_schema=GetRedditPostCommentsBlock.Output, disabled=( not settings.secrets.reddit_client_id or not settings.secrets.reddit_client_secret @@ -1060,10 +1255,10 @@ class GetPostCommentsBlock(Block): test_output=[ ( "comment", - RedditCommentData( - id="comment1", + RedditComment( + comment_id="comment1", post_id="abc123", - parent_id="t3_abc123", + parent_comment_id=None, author="user1", body="Comment body 1", score=10, @@ -1076,10 +1271,10 @@ class GetPostCommentsBlock(Block): ), ( "comment", - RedditCommentData( - id="comment2", + RedditComment( + comment_id="comment2", post_id="abc123", - parent_id="t3_abc123", + parent_comment_id=None, author="user2", body="Comment body 2", score=5, @@ -1093,10 +1288,10 @@ class GetPostCommentsBlock(Block): ( "comments", [ - RedditCommentData( - id="comment1", + RedditComment( + comment_id="comment1", post_id="abc123", - parent_id="t3_abc123", + parent_comment_id=None, author="user1", body="Comment body 1", score=10, @@ -1106,10 +1301,10 @@ class GetPostCommentsBlock(Block): permalink="/r/test/comments/abc123/test/comment1/", depth=0, ), - RedditCommentData( - id="comment2", + RedditComment( + comment_id="comment2", post_id="abc123", - parent_id="t3_abc123", + parent_comment_id=None, author="user2", body="Comment body 2", score=5, @@ -1121,6 +1316,7 @@ class GetPostCommentsBlock(Block): ), ], ), + ("post_id", "abc123"), ], test_mock={ "get_comments": lambda creds, post_id, limit, sort: [ @@ -1156,8 +1352,8 @@ class GetPostCommentsBlock(Block): @staticmethod def get_comments( - creds: RedditCredentials, post_id: str, limit: int, sort: str - ) -> list: + creds: RedditCredentials, post_id: str, limit: int, sort: CommentSort + ) -> list[Comment]: client = get_praw(creds) if post_id.startswith("t3_"): post_id = post_id[3:] @@ -1168,7 +1364,10 @@ class GetPostCommentsBlock(Block): # Return only top-level comments (depth=0), limited # CommentForest supports indexing, so use slicing directly max_comments = min(limit, 100) - return [submission.comments[i] for i in range(min(len(submission.comments), max_comments))] + return [ + submission.comments[i] + for i in range(min(len(submission.comments), max_comments)) + ] async def run( self, input_data: Input, *, credentials: RedditCredentials, **kwargs @@ -1183,14 +1382,20 @@ class GetPostCommentsBlock(Block): all_comments = [] for comment in comments: # Extract post_id from link_id (format: t3_xxxxx) - post_id = comment.link_id - if post_id.startswith("t3_"): - post_id = post_id[3:] + comment_post_id = comment.link_id + if comment_post_id.startswith("t3_"): + comment_post_id = comment_post_id[3:] - comment_data = RedditCommentData( - id=comment.id, - post_id=post_id, - parent_id=comment.parent_id, + # parent_comment_id is None for top-level comments (parent is a post: t3_) + # For replies, extract the comment ID from t1_xxxxx + parent_comment_id = None + if comment.parent_id.startswith("t1_"): + parent_comment_id = comment.parent_id[3:] + + comment_data = RedditComment( + comment_id=comment.id, + post_id=comment_post_id, + parent_comment_id=parent_comment_id, author=str(comment.author) if comment.author else "[deleted]", body=comment.body, score=comment.score, @@ -1203,11 +1408,12 @@ class GetPostCommentsBlock(Block): all_comments.append(comment_data) yield "comment", comment_data yield "comments", all_comments + yield "post_id", input_data.post_id except Exception as e: yield "error", str(e) -class GetCommentRepliesBlock(Block): +class GetRedditCommentRepliesBlock(Block): """Get replies to a specific Reddit comment.""" class Input(BlockSchemaInput): @@ -1224,17 +1430,25 @@ class GetCommentRepliesBlock(Block): ) class Output(BlockSchemaOutput): - reply: RedditCommentData = SchemaField(description="A reply to the comment") - replies: list[RedditCommentData] = SchemaField(description="All replies") - error: str = SchemaField(description="Error message if replies couldn't be fetched") + reply: RedditComment = SchemaField(description="A reply to the comment") + replies: list[RedditComment] = SchemaField(description="All replies") + comment_id: str = SchemaField( + description="The parent comment ID (pass-through for chaining)" + ) + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + error: str = SchemaField( + description="Error message if replies couldn't be fetched" + ) def __init__(self): super().__init__( id="7fa83965-7289-432f-98a9-1575f5bcc8f1", description="Get replies to a specific Reddit comment.", categories={BlockCategory.SOCIAL}, - input_schema=GetCommentRepliesBlock.Input, - output_schema=GetCommentRepliesBlock.Output, + input_schema=GetRedditCommentRepliesBlock.Input, + output_schema=GetRedditCommentRepliesBlock.Output, disabled=( not settings.secrets.reddit_client_id or not settings.secrets.reddit_client_secret @@ -1249,10 +1463,10 @@ class GetCommentRepliesBlock(Block): test_output=[ ( "reply", - RedditCommentData( - id="reply1", + RedditComment( + comment_id="reply1", post_id="abc123", - parent_id="t1_comment1", + parent_comment_id="comment1", author="replier1", body="Reply body 1", score=3, @@ -1266,10 +1480,10 @@ class GetCommentRepliesBlock(Block): ( "replies", [ - RedditCommentData( - id="reply1", + RedditComment( + comment_id="reply1", post_id="abc123", - parent_id="t1_comment1", + parent_comment_id="comment1", author="replier1", body="Reply body 1", score=3, @@ -1281,6 +1495,8 @@ class GetCommentRepliesBlock(Block): ), ], ), + ("comment_id", "comment1"), + ("post_id", "abc123"), ], test_mock={ "get_replies": lambda creds, comment_id, post_id, limit: [ @@ -1304,7 +1520,7 @@ class GetCommentRepliesBlock(Block): @staticmethod def get_replies( creds: RedditCredentials, comment_id: str, post_id: str, limit: int - ) -> list: + ) -> list[Comment]: client = get_praw(creds) if post_id.startswith("t3_"): post_id = post_id[3:] @@ -1352,14 +1568,19 @@ class GetCommentRepliesBlock(Block): ) all_replies = [] for reply in replies: - post_id = reply.link_id - if post_id.startswith("t3_"): - post_id = post_id[3:] + reply_post_id = reply.link_id + if reply_post_id.startswith("t3_"): + reply_post_id = reply_post_id[3:] - reply_data = RedditCommentData( - id=reply.id, - post_id=post_id, - parent_id=reply.parent_id, + # parent_comment_id is the parent comment (always present for replies) + parent_comment_id = None + if reply.parent_id.startswith("t1_"): + parent_comment_id = reply.parent_id[3:] + + reply_data = RedditComment( + comment_id=reply.id, + post_id=reply_post_id, + parent_comment_id=parent_comment_id, author=str(reply.author) if reply.author else "[deleted]", body=reply.body, score=reply.score, @@ -1372,11 +1593,13 @@ class GetCommentRepliesBlock(Block): all_replies.append(reply_data) yield "reply", reply_data yield "replies", all_replies + yield "comment_id", input_data.comment_id + yield "post_id", input_data.post_id except Exception as e: yield "error", str(e) -class GetCommentBlock(Block): +class GetRedditCommentBlock(Block): """Get details about a specific Reddit comment.""" class Input(BlockSchemaInput): @@ -1386,16 +1609,18 @@ class GetCommentBlock(Block): ) class Output(BlockSchemaOutput): - comment: RedditCommentData = SchemaField(description="The comment details") - error: str = SchemaField(description="Error message if comment couldn't be fetched") + comment: RedditComment = SchemaField(description="The comment details") + error: str = SchemaField( + description="Error message if comment couldn't be fetched" + ) def __init__(self): super().__init__( id="72cb311a-5998-4e0a-9bc4-f1b67a97284e", description="Get details about a specific Reddit comment by its ID.", categories={BlockCategory.SOCIAL}, - input_schema=GetCommentBlock.Input, - output_schema=GetCommentBlock.Output, + input_schema=GetRedditCommentBlock.Input, + output_schema=GetRedditCommentBlock.Output, disabled=( not settings.secrets.reddit_client_id or not settings.secrets.reddit_client_secret @@ -1408,10 +1633,10 @@ class GetCommentBlock(Block): test_output=[ ( "comment", - RedditCommentData( - id="comment1", + RedditComment( + comment_id="comment1", post_id="abc123", - parent_id="t3_abc123", + parent_comment_id=None, author="user1", body="Comment body", score=10, @@ -1457,10 +1682,15 @@ class GetCommentBlock(Block): if post_id.startswith("t3_"): post_id = post_id[3:] - comment_data = RedditCommentData( - id=comment.id, + # parent_comment_id is None for top-level comments (parent is a post: t3_) + parent_comment_id = None + if comment.parent_id.startswith("t1_"): + parent_comment_id = comment.parent_id[3:] + + comment_data = RedditComment( + comment_id=comment.id, post_id=post_id, - parent_id=comment.parent_id, + parent_comment_id=parent_comment_id, author=str(comment.author) if comment.author else "[deleted]", body=comment.body, score=comment.score, @@ -1468,8 +1698,833 @@ class GetCommentBlock(Block): edited=bool(comment.edited), is_submitter=comment.is_submitter, permalink=comment.permalink, - depth=comment.depth, + # depth is only available when comments are fetched as part of a tree, + # not when fetched directly by ID + depth=getattr(comment, "depth", 0), ) yield "comment", comment_data except Exception as e: yield "error", str(e) + + +class ReplyToRedditCommentBlock(Block): + """Reply to a specific Reddit comment.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to reply to", + ) + reply_text: str = SchemaField( + description="The text content of the reply", + ) + + class Output(BlockSchemaOutput): + comment_id: str = SchemaField(description="ID of the newly created reply") + parent_comment_id: str = SchemaField( + description="The parent comment ID (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if reply failed") + + def __init__(self): + super().__init__( + id="7635b059-3a9f-4f7d-b499-1b56c4f76f4f", + description="Reply to a specific Reddit comment. Useful for threaded conversations.", + categories={BlockCategory.SOCIAL}, + input_schema=ReplyToRedditCommentBlock.Input, + output_schema=ReplyToRedditCommentBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "parent_comment", + "reply_text": "This is a reply", + }, + test_output=[ + ("comment_id", "new_reply_id"), + ("parent_comment_id", "parent_comment"), + ], + test_mock={ + "reply_to_comment": lambda creds, comment_id, reply_text: "new_reply_id" + }, + ) + + @staticmethod + def reply_to_comment( + creds: RedditCredentials, comment_id: str, reply_text: str + ) -> str: + client = get_praw(creds) + if comment_id.startswith("t1_"): + comment_id = comment_id[3:] + comment = client.comment(id=comment_id) + reply = comment.reply(reply_text) + if not reply: + raise ValueError("Failed to post reply.") + return reply.id + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + new_comment_id = self.reply_to_comment( + credentials, input_data.comment_id, input_data.reply_text + ) + yield "comment_id", new_comment_id + yield "parent_comment_id", input_data.comment_id + except Exception as e: + yield "error", str(e) + + +class RedditUserProfileSubreddit(BaseModel): + """Information about a user's profile subreddit.""" + + name: str + title: str + public_description: str + subscribers: int + over_18: bool + + +class RedditUserInfo(BaseModel): + """Information about a Reddit user.""" + + username: str + user_id: str + comment_karma: int + link_karma: int + total_karma: int + created_utc: float + is_gold: bool + is_mod: bool + has_verified_email: bool + moderated_subreddits: list[str] + profile_subreddit: RedditUserProfileSubreddit | None + + +class GetRedditUserInfoBlock(Block): + """Get information about a Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + username: str = SchemaField( + description="The Reddit username to look up (without /u/ prefix)", + ) + + class Output(BlockSchemaOutput): + user: RedditUserInfo = SchemaField(description="User information") + username: str = SchemaField( + description="The username (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if user lookup failed") + + def __init__(self): + super().__init__( + id="1b4c6bd1-4f28-4bad-9ae9-e7034a0f61ff", + description="Get information about a Reddit user including karma, account age, and verification status.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditUserInfoBlock.Input, + output_schema=GetRedditUserInfoBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "username": "testuser", + }, + test_output=[ + ( + "user", + RedditUserInfo( + username="testuser", + user_id="abc123", + comment_karma=1000, + link_karma=500, + total_karma=1500, + created_utc=1234567890.0, + is_gold=False, + is_mod=True, + has_verified_email=True, + moderated_subreddits=["python", "learnpython"], + profile_subreddit=RedditUserProfileSubreddit( + name="u_testuser", + title="testuser's profile", + public_description="A test user", + subscribers=100, + over_18=False, + ), + ), + ), + ("username", "testuser"), + ], + test_mock={ + "get_user_info": lambda creds, username: MockObject( + name="testuser", + id="abc123", + comment_karma=1000, + link_karma=500, + total_karma=1500, + created_utc=1234567890.0, + is_gold=False, + is_mod=True, + has_verified_email=True, + subreddit=MockObject( + display_name="u_testuser", + title="testuser's profile", + public_description="A test user", + subscribers=100, + over_18=False, + ), + ), + "get_moderated_subreddits": lambda creds, username: [ + MockObject(display_name="python"), + MockObject(display_name="learnpython"), + ], + }, + ) + + @staticmethod + def get_user_info(creds: RedditCredentials, username: str): + client = get_praw(creds) + if username.startswith("u/"): + username = username[2:] + return client.redditor(username) + + @staticmethod + def get_moderated_subreddits(creds: RedditCredentials, username: str) -> list: + client = get_praw(creds) + if username.startswith("u/"): + username = username[2:] + redditor = client.redditor(username) + return list(redditor.moderated()) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + redditor = self.get_user_info(credentials, input_data.username) + moderated = self.get_moderated_subreddits(credentials, input_data.username) + + # Extract moderated subreddit names + moderated_subreddits = [sub.display_name for sub in moderated] + + # Get profile subreddit info if available + profile_subreddit = None + if hasattr(redditor, "subreddit") and redditor.subreddit: + try: + profile_subreddit = RedditUserProfileSubreddit( + name=redditor.subreddit.display_name, + title=redditor.subreddit.title or "", + public_description=redditor.subreddit.public_description or "", + subscribers=redditor.subreddit.subscribers or 0, + over_18=( + redditor.subreddit.over_18 + if hasattr(redditor.subreddit, "over_18") + else False + ), + ) + except Exception: + # Profile subreddit may not be accessible + pass + + user_info = RedditUserInfo( + username=redditor.name, + user_id=redditor.id, + comment_karma=redditor.comment_karma, + link_karma=redditor.link_karma, + total_karma=redditor.total_karma, + created_utc=redditor.created_utc, + is_gold=redditor.is_gold, + is_mod=redditor.is_mod, + has_verified_email=redditor.has_verified_email, + moderated_subreddits=moderated_subreddits, + profile_subreddit=profile_subreddit, + ) + yield "user", user_info + yield "username", input_data.username + except Exception as e: + yield "error", str(e) + + +class SendRedditMessageBlock(Block): + """Send a private message to a Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + username: str = SchemaField( + description="The Reddit username to send a message to (without /u/ prefix)", + ) + subject: str = SchemaField( + description="The subject line of the message", + ) + message: str = SchemaField( + description="The body content of the message", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the message was sent") + username: str = SchemaField( + description="The username (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if sending failed") + + def __init__(self): + super().__init__( + id="7921101a-0537-4259-82ea-bc186ca6b1b6", + description="Send a private message (DM) to a Reddit user.", + categories={BlockCategory.SOCIAL}, + input_schema=SendRedditMessageBlock.Input, + output_schema=SendRedditMessageBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "username": "testuser", + "subject": "Hello", + "message": "This is a test message", + }, + test_output=[ + ("success", True), + ("username", "testuser"), + ], + test_mock={"send_message": lambda creds, username, subject, message: True}, + ) + + @staticmethod + def send_message( + creds: RedditCredentials, username: str, subject: str, message: str + ) -> bool: + client = get_praw(creds) + if username.startswith("u/"): + username = username[2:] + redditor = client.redditor(username) + redditor.message(subject=subject, message=message) + return True + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success = self.send_message( + credentials, + input_data.username, + input_data.subject, + input_data.message, + ) + yield "success", success + yield "username", input_data.username + except Exception as e: + yield "error", str(e) + + +class RedditInboxItem(BaseModel): + """A Reddit inbox item (message, comment reply, or mention).""" + + item_id: str + item_type: str # "message", "comment_reply", "mention" + subject: str + body: str + author: str + created_utc: float + is_read: bool + context: str | None # permalink for comments, None for messages + + +class GetRedditInboxBlock(Block): + """Get messages and notifications from Reddit inbox.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + inbox_type: InboxType = SchemaField( + description="Type of inbox items to fetch", + default="unread", + ) + limit: int = SchemaField( + description="Maximum number of items to fetch", + default=25, + ) + mark_read: bool = SchemaField( + description="Whether to mark fetched items as read", + default=False, + ) + + class Output(BlockSchemaOutput): + item: RedditInboxItem = SchemaField(description="An inbox item") + items: list[RedditInboxItem] = SchemaField(description="All fetched items") + error: str = SchemaField(description="Error message if fetch failed") + + def __init__(self): + super().__init__( + id="5a91bb34-7ffe-4b9e-957b-9d4f8fe8dbc9", + description="Get messages, mentions, and comment replies from your Reddit inbox.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditInboxBlock.Input, + output_schema=GetRedditInboxBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "inbox_type": "unread", + "limit": 10, + }, + test_output=[ + ( + "item", + RedditInboxItem( + item_id="msg123", + item_type="message", + subject="Hello", + body="Test message body", + author="sender_user", + created_utc=1234567890.0, + is_read=False, + context=None, + ), + ), + ( + "items", + [ + RedditInboxItem( + item_id="msg123", + item_type="message", + subject="Hello", + body="Test message body", + author="sender_user", + created_utc=1234567890.0, + is_read=False, + context=None, + ), + ], + ), + ], + test_mock={ + "get_inbox": lambda creds, inbox_type, limit: [ + MockObject( + id="msg123", + subject="Hello", + body="Test message body", + author="sender_user", + created_utc=1234567890.0, + new=True, + context=None, + was_comment=False, + ), + ] + }, + ) + + @staticmethod + def get_inbox(creds: RedditCredentials, inbox_type: InboxType, limit: int) -> list: + client = get_praw(creds) + inbox = client.inbox + + if inbox_type == "all": + items = inbox.all(limit=limit) + elif inbox_type == "unread": + items = inbox.unread(limit=limit) + elif inbox_type == "messages": + items = inbox.messages(limit=limit) + elif inbox_type == "mentions": + items = inbox.mentions(limit=limit) + elif inbox_type == "comment_replies": + items = inbox.comment_replies(limit=limit) + else: + items = inbox.unread(limit=limit) + + return list(items) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + raw_items = self.get_inbox( + credentials, input_data.inbox_type, input_data.limit + ) + all_items = [] + + for item in raw_items: + # Determine item type + if hasattr(item, "was_comment") and item.was_comment: + if hasattr(item, "subject") and "mention" in item.subject.lower(): + item_type = "mention" + else: + item_type = "comment_reply" + else: + item_type = "message" + + inbox_item = RedditInboxItem( + item_id=item.id, + item_type=item_type, + subject=item.subject if hasattr(item, "subject") else "", + body=item.body, + author=str(item.author) if item.author else "[deleted]", + created_utc=item.created_utc, + is_read=not item.new, + context=item.context if hasattr(item, "context") else None, + ) + all_items.append(inbox_item) + yield "item", inbox_item + + # Mark as read if requested + if input_data.mark_read and raw_items: + client = get_praw(credentials) + client.inbox.mark_read(raw_items) + + yield "items", all_items + except Exception as e: + yield "error", str(e) + + +class DeleteRedditPostBlock(Block): + """Delete a Reddit post that you own.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_id: str = SchemaField( + description="The ID of the post to delete (must be your own post)", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the deletion was successful") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if deletion failed") + + def __init__(self): + super().__init__( + id="72e4730a-d66d-4785-8e54-5ab3af450c81", + description="Delete a Reddit post that you own.", + categories={BlockCategory.SOCIAL}, + input_schema=DeleteRedditPostBlock.Input, + output_schema=DeleteRedditPostBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_id": "abc123", + }, + test_output=[ + ("success", True), + ("post_id", "abc123"), + ], + test_mock={"delete_post": lambda creds, post_id: True}, + ) + + @staticmethod + def delete_post(creds: RedditCredentials, post_id: str) -> bool: + client = get_praw(creds) + if post_id.startswith("t3_"): + post_id = post_id[3:] + submission = client.submission(id=post_id) + submission.delete() + return True + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success = self.delete_post(credentials, input_data.post_id) + yield "success", success + yield "post_id", input_data.post_id + except Exception as e: + yield "error", str(e) + + +class DeleteRedditCommentBlock(Block): + """Delete a Reddit comment that you own.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to delete (must be your own comment)", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the deletion was successful") + comment_id: str = SchemaField( + description="The comment ID (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if deletion failed") + + def __init__(self): + super().__init__( + id="2650584d-434f-46db-81ef-26c8d8d41f81", + description="Delete a Reddit comment that you own.", + categories={BlockCategory.SOCIAL}, + input_schema=DeleteRedditCommentBlock.Input, + output_schema=DeleteRedditCommentBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "xyz789", + }, + test_output=[ + ("success", True), + ("comment_id", "xyz789"), + ], + test_mock={"delete_comment": lambda creds, comment_id: True}, + ) + + @staticmethod + def delete_comment(creds: RedditCredentials, comment_id: str) -> bool: + client = get_praw(creds) + if comment_id.startswith("t1_"): + comment_id = comment_id[3:] + comment = client.comment(id=comment_id) + comment.delete() + return True + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success = self.delete_comment(credentials, input_data.comment_id) + yield "success", success + yield "comment_id", input_data.comment_id + except Exception as e: + yield "error", str(e) + + +class SubredditFlair(BaseModel): + """A subreddit link flair template.""" + + flair_id: str + text: str + text_editable: bool + css_class: str = "" # The CSS class for styling (from flair_css_class) + + +class GetSubredditFlairsBlock(Block): + """Get available link flairs for a subreddit.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit name (without /r/ prefix)", + ) + + class Output(BlockSchemaOutput): + flair: SubredditFlair = SchemaField(description="A flair option") + flairs: list[SubredditFlair] = SchemaField(description="All available flairs") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if fetch failed") + + def __init__(self): + super().__init__( + id="ada08f34-a7a9-44aa-869f-0638fa4e0a84", + description="Get available link flair options for a subreddit.", + categories={BlockCategory.SOCIAL}, + input_schema=GetSubredditFlairsBlock.Input, + output_schema=GetSubredditFlairsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "test", + }, + test_output=[ + ( + "flair", + SubredditFlair( + flair_id="abc123", + text="Discussion", + text_editable=False, + css_class="discussion", + ), + ), + ( + "flairs", + [ + SubredditFlair( + flair_id="abc123", + text="Discussion", + text_editable=False, + css_class="discussion", + ), + ], + ), + ("subreddit", "test"), + ], + test_mock={ + "get_flairs": lambda creds, subreddit: [ + { + "flair_template_id": "abc123", + "flair_text": "Discussion", + "flair_text_editable": False, + "flair_css_class": "discussion", + }, + ] + }, + ) + + @staticmethod + def get_flairs(creds: RedditCredentials, subreddit: str) -> list: + client = get_praw(creds) + # Use /r/{subreddit}/api/flairselector endpoint directly with is_newlink=True + # This returns link flairs available for new submissions without requiring mod access + # The link_templates API is moderator-only, so we use flairselector instead + # Path must include the subreddit prefix per Reddit API docs + response = client.post( + f"r/{subreddit}/api/flairselector", + data={"is_newlink": "true"}, + ) + # Response contains 'choices' list with available flairs + choices = response.get("choices", []) + return choices + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + raw_flairs = self.get_flairs(credentials, input_data.subreddit) + all_flairs = [] + + for flair in raw_flairs: + # /api/flairselector returns flairs with flair_template_id, flair_text, etc. + flair_data = SubredditFlair( + flair_id=flair.get("flair_template_id", ""), + text=flair.get("flair_text", ""), + text_editable=flair.get("flair_text_editable", False), + css_class=flair.get("flair_css_class", ""), + ) + all_flairs.append(flair_data) + yield "flair", flair_data + + yield "flairs", all_flairs + yield "subreddit", input_data.subreddit + except Exception as e: + yield "error", str(e) + + +class SubredditRule(BaseModel): + """A subreddit rule.""" + + short_name: str + description: str + kind: str # "all", "link", "comment" + violation_reason: str + priority: int + + +class GetSubredditRulesBlock(Block): + """Get the rules for a subreddit.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit name (without /r/ prefix)", + ) + + class Output(BlockSchemaOutput): + rule: SubredditRule = SchemaField(description="A subreddit rule") + rules: list[SubredditRule] = SchemaField(description="All subreddit rules") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if fetch failed") + + def __init__(self): + super().__init__( + id="222aa36c-fa70-4879-8e8a-37d100175f5c", + description="Get the rules for a subreddit to ensure compliance before posting.", + categories={BlockCategory.SOCIAL}, + input_schema=GetSubredditRulesBlock.Input, + output_schema=GetSubredditRulesBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "test", + }, + test_output=[ + ( + "rule", + SubredditRule( + short_name="No spam", + description="Do not post spam or self-promotional content.", + kind="all", + violation_reason="Spam", + priority=0, + ), + ), + ( + "rules", + [ + SubredditRule( + short_name="No spam", + description="Do not post spam or self-promotional content.", + kind="all", + violation_reason="Spam", + priority=0, + ), + ], + ), + ("subreddit", "test"), + ], + test_mock={ + "get_rules": lambda creds, subreddit: [ + MockObject( + short_name="No spam", + description="Do not post spam or self-promotional content.", + kind="all", + violation_reason="Spam", + priority=0, + ), + ] + }, + ) + + @staticmethod + def get_rules(creds: RedditCredentials, subreddit: str) -> list: + client = get_praw(creds) + sub = client.subreddit(subreddit) + return list(sub.rules) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + raw_rules = self.get_rules(credentials, input_data.subreddit) + all_rules = [] + + for idx, rule in enumerate(raw_rules): + rule_data = SubredditRule( + short_name=rule.short_name, + description=rule.description or "", + kind=rule.kind, + violation_reason=rule.violation_reason or rule.short_name, + priority=idx, + ) + all_rules.append(rule_data) + yield "rule", rule_data + + yield "rules", all_rules + yield "subreddit", input_data.subreddit + except Exception as e: + yield "error", str(e) diff --git a/autogpt_platform/backend/backend/integrations/oauth/reddit.py b/autogpt_platform/backend/backend/integrations/oauth/reddit.py index 8afe915189..f714854a70 100644 --- a/autogpt_platform/backend/backend/integrations/oauth/reddit.py +++ b/autogpt_platform/backend/backend/integrations/oauth/reddit.py @@ -31,6 +31,8 @@ class RedditOAuthHandler(BaseOAuthHandler): "submit", # Submit new posts and comments "edit", # Edit own posts and comments "history", # Access user's post history + "privatemessages", # Access inbox and send private messages + "flair", # Access and set flair on posts/subreddits ] AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"