diff --git a/.branchlet.json b/.branchlet.json new file mode 100644 index 0000000000..cc13ff9f74 --- /dev/null +++ b/.branchlet.json @@ -0,0 +1,37 @@ +{ + "worktreeCopyPatterns": [ + ".env*", + ".vscode/**", + ".auth/**", + ".claude/**", + "autogpt_platform/.env*", + "autogpt_platform/backend/.env*", + "autogpt_platform/frontend/.env*", + "autogpt_platform/frontend/.auth/**", + "autogpt_platform/db/docker/.env*" + ], + "worktreeCopyIgnores": [ + "**/node_modules/**", + "**/dist/**", + "**/.git/**", + "**/Thumbs.db", + "**/.DS_Store", + "**/.next/**", + "**/__pycache__/**", + "**/.ruff_cache/**", + "**/.pytest_cache/**", + "**/*.pyc", + "**/playwright-report/**", + "**/logs/**", + "**/site/**" + ], + "worktreePathTemplate": "$BASE_PATH.worktree", + "postCreateCmd": [ + "cd autogpt_platform/autogpt_libs && poetry install", + "cd autogpt_platform/backend && poetry install && poetry run prisma generate", + "cd autogpt_platform/frontend && pnpm install", + "cd docs && pip install -r requirements.txt" + ], + "terminalCommand": "code .", + "deleteBranchWithWorktree": false +} diff --git a/.dockerignore b/.dockerignore index 94bf1742f1..c9524ce700 100644 --- a/.dockerignore +++ b/.dockerignore @@ -16,6 +16,7 @@ !autogpt_platform/backend/poetry.lock !autogpt_platform/backend/README.md !autogpt_platform/backend/.env +!autogpt_platform/backend/gen_prisma_types_stub.py # Platform - Market !autogpt_platform/market/market/ diff --git a/.github/workflows/claude-dependabot.yml b/.github/workflows/claude-dependabot.yml index 20b6f1d28e..1fd0da3d8e 100644 --- a/.github/workflows/claude-dependabot.yml +++ b/.github/workflows/claude-dependabot.yml @@ -74,7 +74,7 @@ jobs: - name: Generate Prisma Client working-directory: autogpt_platform/backend - run: poetry run prisma generate + run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) - name: Set up Node.js diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 3f5e8c22ec..71c6ef49c2 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -90,7 +90,7 @@ jobs: - name: Generate Prisma Client working-directory: autogpt_platform/backend - run: poetry run prisma generate + run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) - name: Set up Node.js diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index 13ef01cc44..aac8befee0 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -72,7 +72,7 @@ jobs: - name: Generate Prisma Client working-directory: autogpt_platform/backend - run: poetry run prisma generate + run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) - name: Set up Node.js @@ -108,6 +108,16 @@ jobs: # run: pnpm playwright install --with-deps chromium # Docker setup for development environment + - name: Free up disk space + run: | + # Remove large unused tools to free disk space for Docker builds + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo docker system prune -af + df -h + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 diff --git a/.github/workflows/platform-backend-ci.yml b/.github/workflows/platform-backend-ci.yml index f962382fa5..da5ab83c1c 100644 --- a/.github/workflows/platform-backend-ci.yml +++ b/.github/workflows/platform-backend-ci.yml @@ -134,7 +134,7 @@ jobs: run: poetry install - name: Generate Prisma Client - run: poetry run prisma generate + run: poetry run prisma generate && poetry run gen-prisma-stub - id: supabase name: Start Supabase diff --git a/autogpt_platform/Makefile b/autogpt_platform/Makefile index d99fee49d7..2ff454e392 100644 --- a/autogpt_platform/Makefile +++ b/autogpt_platform/Makefile @@ -12,6 +12,7 @@ reset-db: rm -rf db/docker/volumes/db/data cd backend && poetry run prisma migrate deploy cd backend && poetry run prisma generate + cd backend && poetry run gen-prisma-stub # View logs for core services logs-core: @@ -33,6 +34,7 @@ init-env: migrate: cd backend && poetry run prisma migrate deploy cd backend && poetry run prisma generate + cd backend && poetry run gen-prisma-stub run-backend: cd backend && poetry run app diff --git a/autogpt_platform/backend/Dockerfile b/autogpt_platform/backend/Dockerfile index 7f51bad3a1..b3389d1787 100644 --- a/autogpt_platform/backend/Dockerfile +++ b/autogpt_platform/backend/Dockerfile @@ -48,7 +48,8 @@ RUN poetry install --no-ansi --no-root # Generate Prisma client COPY autogpt_platform/backend/schema.prisma ./ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py -RUN poetry run prisma generate +COPY autogpt_platform/backend/gen_prisma_types_stub.py ./ +RUN poetry run prisma generate && poetry run gen-prisma-stub FROM debian:13-slim AS server_dependencies diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index 69ed0d2730..1c17e7b36c 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -489,7 +489,7 @@ async def update_agent_version_in_library( agent_graph_version: int, ) -> library_model.LibraryAgent: """ - Updates the agent version in the library if useGraphIsActiveVersion is True. + Updates the agent version in the library for any agent owned by the user. Args: user_id: Owner of the LibraryAgent. @@ -498,20 +498,31 @@ async def update_agent_version_in_library( Raises: DatabaseError: If there's an error with the update. + NotFoundError: If no library agent is found for this user and agent. """ logger.debug( f"Updating agent version in library for user #{user_id}, " f"agent #{agent_graph_id} v{agent_graph_version}" ) - try: - library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise( + async with transaction() as tx: + library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise( where={ "userId": user_id, "agentGraphId": agent_graph_id, - "useGraphIsActiveVersion": True, }, ) - lib = await prisma.models.LibraryAgent.prisma().update( + + # Delete any conflicting LibraryAgent for the target version + await prisma.models.LibraryAgent.prisma(tx).delete_many( + where={ + "userId": user_id, + "agentGraphId": agent_graph_id, + "agentGraphVersion": agent_graph_version, + "id": {"not": library_agent.id}, + } + ) + + lib = await prisma.models.LibraryAgent.prisma(tx).update( where={"id": library_agent.id}, data={ "AgentGraph": { @@ -525,13 +536,13 @@ async def update_agent_version_in_library( }, include={"AgentGraph": True}, ) - if lib is None: - raise NotFoundError(f"Library agent {library_agent.id} not found") - return library_model.LibraryAgent.from_db(lib) - except prisma.errors.PrismaError as e: - logger.error(f"Database error updating agent version in library: {e}") - raise DatabaseError("Failed to update agent version in library") from e + if lib is None: + raise NotFoundError( + f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}" + ) + + return library_model.LibraryAgent.from_db(lib) async def update_library_agent( @@ -825,6 +836,7 @@ async def add_store_agent_to_library( } }, "isCreatedByUser": False, + "useGraphIsActiveVersion": False, "settings": SafeJson( _initialize_graph_settings(graph_model).model_dump() ), diff --git a/autogpt_platform/backend/backend/api/features/library/model.py b/autogpt_platform/backend/backend/api/features/library/model.py index c20f82afae..56fad7bfd3 100644 --- a/autogpt_platform/backend/backend/api/features/library/model.py +++ b/autogpt_platform/backend/backend/api/features/library/model.py @@ -48,6 +48,7 @@ class LibraryAgent(pydantic.BaseModel): id: str graph_id: str graph_version: int + owner_user_id: str # ID of user who owns/created this agent graph image_url: str | None @@ -163,6 +164,7 @@ class LibraryAgent(pydantic.BaseModel): id=agent.id, graph_id=agent.agentGraphId, graph_version=agent.agentGraphVersion, + owner_user_id=agent.userId, image_url=agent.imageUrl, creator_name=creator_name, creator_image_url=creator_image_url, diff --git a/autogpt_platform/backend/backend/api/features/library/routes_test.py b/autogpt_platform/backend/backend/api/features/library/routes_test.py index ad28b5b6bd..0f05240a7f 100644 --- a/autogpt_platform/backend/backend/api/features/library/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/library/routes_test.py @@ -42,6 +42,7 @@ async def test_get_library_agents_success( id="test-agent-1", graph_id="test-agent-1", graph_version=1, + owner_user_id=test_user_id, name="Test Agent 1", description="Test Description 1", image_url=None, @@ -64,6 +65,7 @@ async def test_get_library_agents_success( id="test-agent-2", graph_id="test-agent-2", graph_version=1, + owner_user_id=test_user_id, name="Test Agent 2", description="Test Description 2", image_url=None, @@ -138,6 +140,7 @@ async def test_get_favorite_library_agents_success( id="test-agent-1", graph_id="test-agent-1", graph_version=1, + owner_user_id=test_user_id, name="Favorite Agent 1", description="Test Favorite Description 1", image_url=None, @@ -205,6 +208,7 @@ def test_add_agent_to_library_success( id="test-library-agent-id", graph_id="test-agent-1", graph_version=1, + owner_user_id=test_user_id, name="Test Agent 1", description="Test Description 1", image_url=None, diff --git a/autogpt_platform/backend/backend/api/features/store/db.py b/autogpt_platform/backend/backend/api/features/store/db.py index 77f02cdd99..c3097ab5aa 100644 --- a/autogpt_platform/backend/backend/api/features/store/db.py +++ b/autogpt_platform/backend/backend/api/features/store/db.py @@ -615,6 +615,7 @@ async def get_store_submissions( submission_models = [] for sub in submissions: submission_model = store_model.StoreSubmission( + listing_id=sub.listing_id, agent_id=sub.agent_id, agent_version=sub.agent_version, name=sub.name, @@ -668,35 +669,48 @@ async def delete_store_submission( submission_id: str, ) -> bool: """ - Delete a store listing submission as the submitting user. + Delete a store submission version as the submitting user. Args: user_id: ID of the authenticated user - submission_id: ID of the submission to be deleted + submission_id: StoreListingVersion ID to delete Returns: - bool: True if the submission was successfully deleted, False otherwise + bool: True if successfully deleted """ - logger.debug(f"Deleting store submission {submission_id} for user {user_id}") - try: - # Verify the submission belongs to this user - submission = await prisma.models.StoreListing.prisma().find_first( - where={"agentGraphId": submission_id, "owningUserId": user_id} + # Find the submission version with ownership check + version = await prisma.models.StoreListingVersion.prisma().find_first( + where={"id": submission_id}, include={"StoreListing": True} ) - if not submission: - logger.warning(f"Submission not found for user {user_id}: {submission_id}") - raise store_exceptions.SubmissionNotFoundError( - f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}" + if ( + not version + or not version.StoreListing + or version.StoreListing.owningUserId != user_id + ): + raise store_exceptions.SubmissionNotFoundError("Submission not found") + + # Prevent deletion of approved submissions + if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED: + raise store_exceptions.InvalidOperationError( + "Cannot delete approved submissions" ) - # Delete the submission - await prisma.models.StoreListing.prisma().delete(where={"id": submission.id}) - - logger.debug( - f"Successfully deleted submission {submission_id} for user {user_id}" + # Delete the version + await prisma.models.StoreListingVersion.prisma().delete( + where={"id": version.id} ) + + # Clean up empty listing if this was the last version + remaining = await prisma.models.StoreListingVersion.prisma().count( + where={"storeListingId": version.storeListingId} + ) + if remaining == 0: + await prisma.models.StoreListing.prisma().delete( + where={"id": version.storeListingId} + ) + return True except Exception as e: @@ -760,9 +774,15 @@ async def create_store_submission( logger.warning( f"Agent not found for user {user_id}: {agent_id} v{agent_version}" ) - raise store_exceptions.AgentNotFoundError( - f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}" - ) + # Provide more user-friendly error message when agent_id is empty + if not agent_id or agent_id.strip() == "": + raise store_exceptions.AgentNotFoundError( + "No agent selected. Please select an agent before submitting to the store." + ) + else: + raise store_exceptions.AgentNotFoundError( + f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}" + ) # Check if listing already exists for this agent existing_listing = await prisma.models.StoreListing.prisma().find_first( @@ -834,6 +854,7 @@ async def create_store_submission( logger.debug(f"Created store listing for agent {agent_id}") # Return submission details return store_model.StoreSubmission( + listing_id=listing.id, agent_id=agent_id, agent_version=agent_version, name=name, @@ -945,81 +966,56 @@ async def edit_store_submission( # Currently we are not allowing user to update the agent associated with a submission # If we allow it in future, then we need a check here to verify the agent belongs to this user. - # Check if we can edit this submission - if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED: + # Only allow editing of PENDING submissions + if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING: raise store_exceptions.InvalidOperationError( - "Cannot edit a rejected submission" - ) - - # For APPROVED submissions, we need to create a new version - if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED: - # Create a new version for the existing listing - return await create_store_version( - user_id=user_id, - agent_id=current_version.agentGraphId, - agent_version=current_version.agentGraphVersion, - store_listing_id=current_version.storeListingId, - name=name, - video_url=video_url, - agent_output_demo_url=agent_output_demo_url, - image_urls=image_urls, - description=description, - sub_heading=sub_heading, - categories=categories, - changes_summary=changes_summary, - recommended_schedule_cron=recommended_schedule_cron, - instructions=instructions, + f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited." ) # For PENDING submissions, we can update the existing version - elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING: - # Update the existing version - updated_version = await prisma.models.StoreListingVersion.prisma().update( - where={"id": store_listing_version_id}, - data=prisma.types.StoreListingVersionUpdateInput( - name=name, - videoUrl=video_url, - agentOutputDemoUrl=agent_output_demo_url, - imageUrls=image_urls, - description=description, - categories=categories, - subHeading=sub_heading, - changesSummary=changes_summary, - recommendedScheduleCron=recommended_schedule_cron, - instructions=instructions, - ), - ) - - logger.debug( - f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}" - ) - - if not updated_version: - raise DatabaseError("Failed to update store listing version") - return store_model.StoreSubmission( - agent_id=current_version.agentGraphId, - agent_version=current_version.agentGraphVersion, + # Update the existing version + updated_version = await prisma.models.StoreListingVersion.prisma().update( + where={"id": store_listing_version_id}, + data=prisma.types.StoreListingVersionUpdateInput( name=name, - sub_heading=sub_heading, - slug=current_version.StoreListing.slug, + videoUrl=video_url, + agentOutputDemoUrl=agent_output_demo_url, + imageUrls=image_urls, description=description, - instructions=instructions, - image_urls=image_urls, - date_submitted=updated_version.submittedAt or updated_version.createdAt, - status=updated_version.submissionStatus, - runs=0, - rating=0.0, - store_listing_version_id=updated_version.id, - changes_summary=changes_summary, - video_url=video_url, categories=categories, - version=updated_version.version, - ) + subHeading=sub_heading, + changesSummary=changes_summary, + recommendedScheduleCron=recommended_schedule_cron, + instructions=instructions, + ), + ) - else: - raise store_exceptions.InvalidOperationError( - f"Cannot edit submission with status: {current_version.submissionStatus}" - ) + logger.debug( + f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}" + ) + + if not updated_version: + raise DatabaseError("Failed to update store listing version") + return store_model.StoreSubmission( + listing_id=current_version.StoreListing.id, + agent_id=current_version.agentGraphId, + agent_version=current_version.agentGraphVersion, + name=name, + sub_heading=sub_heading, + slug=current_version.StoreListing.slug, + description=description, + instructions=instructions, + image_urls=image_urls, + date_submitted=updated_version.submittedAt or updated_version.createdAt, + status=updated_version.submissionStatus, + runs=0, + rating=0.0, + store_listing_version_id=updated_version.id, + changes_summary=changes_summary, + video_url=video_url, + categories=categories, + version=updated_version.version, + ) except ( store_exceptions.SubmissionNotFoundError, @@ -1098,38 +1094,78 @@ async def create_store_version( f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}" ) - # Get the latest version number - latest_version = listing.Versions[0] if listing.Versions else None - - next_version = (latest_version.version + 1) if latest_version else 1 - - # Create a new version for the existing listing - new_version = await prisma.models.StoreListingVersion.prisma().create( - data=prisma.types.StoreListingVersionCreateInput( - version=next_version, - agentGraphId=agent_id, - agentGraphVersion=agent_version, - name=name, - videoUrl=video_url, - agentOutputDemoUrl=agent_output_demo_url, - imageUrls=image_urls, - description=description, - instructions=instructions, - categories=categories, - subHeading=sub_heading, - submissionStatus=prisma.enums.SubmissionStatus.PENDING, - submittedAt=datetime.now(), - changesSummary=changes_summary, - recommendedScheduleCron=recommended_schedule_cron, - storeListingId=store_listing_id, + # Check if there's already a PENDING submission for this agent (any version) + existing_pending_submission = ( + await prisma.models.StoreListingVersion.prisma().find_first( + where=prisma.types.StoreListingVersionWhereInput( + storeListingId=store_listing_id, + agentGraphId=agent_id, + submissionStatus=prisma.enums.SubmissionStatus.PENDING, + isDeleted=False, + ) ) ) + # Handle existing pending submission and create new one atomically + async with transaction() as tx: + # Get the latest version number first + latest_listing = await prisma.models.StoreListing.prisma(tx).find_first( + where=prisma.types.StoreListingWhereInput( + id=store_listing_id, owningUserId=user_id + ), + include={"Versions": {"order_by": {"version": "desc"}, "take": 1}}, + ) + + if not latest_listing: + raise store_exceptions.ListingNotFoundError( + f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}" + ) + + latest_version = ( + latest_listing.Versions[0] if latest_listing.Versions else None + ) + next_version = (latest_version.version + 1) if latest_version else 1 + + # If there's an existing pending submission, delete it atomically before creating new one + if existing_pending_submission: + logger.info( + f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate" + ) + await prisma.models.StoreListingVersion.prisma(tx).delete( + where={"id": existing_pending_submission.id} + ) + logger.debug( + f"Deleted existing pending submission {existing_pending_submission.id}" + ) + + # Create a new version for the existing listing + new_version = await prisma.models.StoreListingVersion.prisma(tx).create( + data=prisma.types.StoreListingVersionCreateInput( + version=next_version, + agentGraphId=agent_id, + agentGraphVersion=agent_version, + name=name, + videoUrl=video_url, + agentOutputDemoUrl=agent_output_demo_url, + imageUrls=image_urls, + description=description, + instructions=instructions, + categories=categories, + subHeading=sub_heading, + submissionStatus=prisma.enums.SubmissionStatus.PENDING, + submittedAt=datetime.now(), + changesSummary=changes_summary, + recommendedScheduleCron=recommended_schedule_cron, + storeListingId=store_listing_id, + ) + ) + logger.debug( f"Created new version for listing {store_listing_id} of agent {agent_id}" ) # Return submission details return store_model.StoreSubmission( + listing_id=listing.id, agent_id=agent_id, agent_version=agent_version, name=name, @@ -1732,15 +1768,12 @@ async def review_store_submission( # Convert to Pydantic model for consistency return store_model.StoreSubmission( + listing_id=(submission.StoreListing.id if submission.StoreListing else ""), agent_id=submission.agentGraphId, agent_version=submission.agentGraphVersion, name=submission.name, sub_heading=submission.subHeading, - slug=( - submission.StoreListing.slug - if hasattr(submission, "storeListing") and submission.StoreListing - else "" - ), + slug=(submission.StoreListing.slug if submission.StoreListing else ""), description=submission.description, instructions=submission.instructions, image_urls=submission.imageUrls or [], @@ -1842,9 +1875,7 @@ async def get_admin_listings_with_versions( where = prisma.types.StoreListingWhereInput(**where_dict) include = prisma.types.StoreListingInclude( Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing( - order_by=prisma.types._StoreListingVersion_version_OrderByInput( - version="desc" - ) + order_by={"version": "desc"} ), OwningUser=True, ) @@ -1869,6 +1900,7 @@ async def get_admin_listings_with_versions( # If we have versions, turn them into StoreSubmission models for version in listing.Versions or []: version_model = store_model.StoreSubmission( + listing_id=listing.id, agent_id=version.agentGraphId, agent_version=version.agentGraphVersion, name=version.name, diff --git a/autogpt_platform/backend/backend/api/features/store/model.py b/autogpt_platform/backend/backend/api/features/store/model.py index b68f1144e5..41eaeb6679 100644 --- a/autogpt_platform/backend/backend/api/features/store/model.py +++ b/autogpt_platform/backend/backend/api/features/store/model.py @@ -110,6 +110,7 @@ class Profile(pydantic.BaseModel): class StoreSubmission(pydantic.BaseModel): + listing_id: str agent_id: str agent_version: int name: str @@ -164,8 +165,12 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel): class StoreSubmissionRequest(pydantic.BaseModel): - agent_id: str - agent_version: int + agent_id: str = pydantic.Field( + ..., min_length=1, description="Agent ID cannot be empty" + ) + agent_version: int = pydantic.Field( + ..., gt=0, description="Agent version must be greater than 0" + ) slug: str name: str sub_heading: str diff --git a/autogpt_platform/backend/backend/api/features/store/model_test.py b/autogpt_platform/backend/backend/api/features/store/model_test.py index a37966601b..fd09a0cf77 100644 --- a/autogpt_platform/backend/backend/api/features/store/model_test.py +++ b/autogpt_platform/backend/backend/api/features/store/model_test.py @@ -138,6 +138,7 @@ def test_creator_details(): def test_store_submission(): submission = store_model.StoreSubmission( + listing_id="listing123", agent_id="agent123", agent_version=1, sub_heading="Test subheading", @@ -159,6 +160,7 @@ def test_store_submissions_response(): response = store_model.StoreSubmissionsResponse( submissions=[ store_model.StoreSubmission( + listing_id="listing123", agent_id="agent123", agent_version=1, sub_heading="Test subheading", diff --git a/autogpt_platform/backend/backend/api/features/store/routes_test.py b/autogpt_platform/backend/backend/api/features/store/routes_test.py index 7fdc0b9ebb..36431c20ec 100644 --- a/autogpt_platform/backend/backend/api/features/store/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/store/routes_test.py @@ -521,6 +521,7 @@ def test_get_submissions_success( mocked_value = store_model.StoreSubmissionsResponse( submissions=[ store_model.StoreSubmission( + listing_id="test-listing-id", name="Test Agent", description="Test agent description", image_urls=["test.jpg"], diff --git a/autogpt_platform/backend/backend/blocks/airtable/_webhook.py b/autogpt_platform/backend/backend/blocks/airtable/_webhook.py index 58e6f95d0c..452630953e 100644 --- a/autogpt_platform/backend/backend/blocks/airtable/_webhook.py +++ b/autogpt_platform/backend/backend/blocks/airtable/_webhook.py @@ -6,6 +6,9 @@ import hashlib import hmac import logging from enum import Enum +from typing import cast + +from prisma.types import Serializable from backend.sdk import ( BaseWebhooksManager, @@ -84,7 +87,9 @@ class AirtableWebhookManager(BaseWebhooksManager): # update webhook config await update_webhook( webhook.id, - config={"base_id": base_id, "cursor": response.cursor}, + config=cast( + dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor} + ), ) event_type = "notification" diff --git a/autogpt_platform/backend/backend/blocks/helpers/review.py b/autogpt_platform/backend/backend/blocks/helpers/review.py new file mode 100644 index 0000000000..f35397e6aa --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/helpers/review.py @@ -0,0 +1,184 @@ +""" +Shared helpers for Human-In-The-Loop (HITL) review functionality. +Used by both the dedicated HumanInTheLoopBlock and blocks that require human review. +""" + +import logging +from typing import Any, Optional + +from prisma.enums import ReviewStatus +from pydantic import BaseModel + +from backend.data.execution import ExecutionContext, ExecutionStatus +from backend.data.human_review import ReviewResult +from backend.executor.manager import async_update_node_execution_status +from backend.util.clients import get_database_manager_async_client + +logger = logging.getLogger(__name__) + + +class ReviewDecision(BaseModel): + """Result of a review decision.""" + + should_proceed: bool + message: str + review_result: ReviewResult + + +class HITLReviewHelper: + """Helper class for Human-In-The-Loop review operations.""" + + @staticmethod + async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]: + """Create or retrieve a human review from the database.""" + return await get_database_manager_async_client().get_or_create_human_review( + **kwargs + ) + + @staticmethod + async def update_node_execution_status(**kwargs) -> None: + """Update the execution status of a node.""" + await async_update_node_execution_status( + db_client=get_database_manager_async_client(), **kwargs + ) + + @staticmethod + async def update_review_processed_status( + node_exec_id: str, processed: bool + ) -> None: + """Update the processed status of a review.""" + return await get_database_manager_async_client().update_review_processed_status( + node_exec_id, processed + ) + + @staticmethod + async def _handle_review_request( + input_data: Any, + user_id: str, + node_exec_id: str, + graph_exec_id: str, + graph_id: str, + graph_version: int, + execution_context: ExecutionContext, + block_name: str = "Block", + editable: bool = False, + ) -> Optional[ReviewResult]: + """ + Handle a review request for a block that requires human review. + + Args: + input_data: The input data to be reviewed + user_id: ID of the user requesting the review + node_exec_id: ID of the node execution + graph_exec_id: ID of the graph execution + graph_id: ID of the graph + graph_version: Version of the graph + execution_context: Current execution context + block_name: Name of the block requesting review + editable: Whether the reviewer can edit the data + + Returns: + ReviewResult if review is complete, None if waiting for human input + + Raises: + Exception: If review creation or status update fails + """ + # Skip review if safe mode is disabled - return auto-approved result + if not execution_context.safe_mode: + logger.info( + f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled" + ) + return ReviewResult( + data=input_data, + status=ReviewStatus.APPROVED, + message="Auto-approved (safe mode disabled)", + processed=True, + node_exec_id=node_exec_id, + ) + + result = await HITLReviewHelper.get_or_create_human_review( + user_id=user_id, + node_exec_id=node_exec_id, + graph_exec_id=graph_exec_id, + graph_id=graph_id, + graph_version=graph_version, + input_data=input_data, + message=f"Review required for {block_name} execution", + editable=editable, + ) + + if result is None: + logger.info( + f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review" + ) + await HITLReviewHelper.update_node_execution_status( + exec_id=node_exec_id, + status=ExecutionStatus.REVIEW, + ) + return None # Signal that execution should pause + + # Mark review as processed if not already done + if not result.processed: + await HITLReviewHelper.update_review_processed_status( + node_exec_id=node_exec_id, processed=True + ) + + return result + + @staticmethod + async def handle_review_decision( + input_data: Any, + user_id: str, + node_exec_id: str, + graph_exec_id: str, + graph_id: str, + graph_version: int, + execution_context: ExecutionContext, + block_name: str = "Block", + editable: bool = False, + ) -> Optional[ReviewDecision]: + """ + Handle a review request and return the decision in a single call. + + Args: + input_data: The input data to be reviewed + user_id: ID of the user requesting the review + node_exec_id: ID of the node execution + graph_exec_id: ID of the graph execution + graph_id: ID of the graph + graph_version: Version of the graph + execution_context: Current execution context + block_name: Name of the block requesting review + editable: Whether the reviewer can edit the data + + Returns: + ReviewDecision if review is complete (approved/rejected), + None if execution should pause (awaiting review) + """ + review_result = await HITLReviewHelper._handle_review_request( + input_data=input_data, + user_id=user_id, + node_exec_id=node_exec_id, + graph_exec_id=graph_exec_id, + graph_id=graph_id, + graph_version=graph_version, + execution_context=execution_context, + block_name=block_name, + editable=editable, + ) + + if review_result is None: + # Still awaiting review - return None to pause execution + return None + + # Review is complete, determine outcome + should_proceed = review_result.status == ReviewStatus.APPROVED + message = review_result.message or ( + "Execution approved by reviewer" + if should_proceed + else "Execution rejected by reviewer" + ) + + return ReviewDecision( + should_proceed=should_proceed, message=message, review_result=review_result + ) diff --git a/autogpt_platform/backend/backend/blocks/human_in_the_loop.py b/autogpt_platform/backend/backend/blocks/human_in_the_loop.py index 13c9fb31db..1e338816c8 100644 --- a/autogpt_platform/backend/backend/blocks/human_in_the_loop.py +++ b/autogpt_platform/backend/backend/blocks/human_in_the_loop.py @@ -3,6 +3,7 @@ from typing import Any from prisma.enums import ReviewStatus +from backend.blocks.helpers.review import HITLReviewHelper from backend.data.block import ( Block, BlockCategory, @@ -11,11 +12,9 @@ from backend.data.block import ( BlockSchemaOutput, BlockType, ) -from backend.data.execution import ExecutionContext, ExecutionStatus +from backend.data.execution import ExecutionContext from backend.data.human_review import ReviewResult from backend.data.model import SchemaField -from backend.executor.manager import async_update_node_execution_status -from backend.util.clients import get_database_manager_async_client logger = logging.getLogger(__name__) @@ -72,32 +71,26 @@ class HumanInTheLoopBlock(Block): ("approved_data", {"name": "John Doe", "age": 30}), ], test_mock={ - "get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult( - data={"name": "John Doe", "age": 30}, - status=ReviewStatus.APPROVED, - message="", - processed=False, - node_exec_id="test-node-exec-id", - ), - "update_node_execution_status": lambda *_args, **_kwargs: None, - "update_review_processed_status": lambda *_args, **_kwargs: None, + "handle_review_decision": lambda **kwargs: type( + "ReviewDecision", + (), + { + "should_proceed": True, + "message": "Test approval message", + "review_result": ReviewResult( + data={"name": "John Doe", "age": 30}, + status=ReviewStatus.APPROVED, + message="", + processed=False, + node_exec_id="test-node-exec-id", + ), + }, + )(), }, ) - async def get_or_create_human_review(self, **kwargs): - return await get_database_manager_async_client().get_or_create_human_review( - **kwargs - ) - - async def update_node_execution_status(self, **kwargs): - return await async_update_node_execution_status( - db_client=get_database_manager_async_client(), **kwargs - ) - - async def update_review_processed_status(self, node_exec_id: str, processed: bool): - return await get_database_manager_async_client().update_review_processed_status( - node_exec_id, processed - ) + async def handle_review_decision(self, **kwargs): + return await HITLReviewHelper.handle_review_decision(**kwargs) async def run( self, @@ -109,7 +102,7 @@ class HumanInTheLoopBlock(Block): graph_id: str, graph_version: int, execution_context: ExecutionContext, - **kwargs, + **_kwargs, ) -> BlockOutput: if not execution_context.safe_mode: logger.info( @@ -119,48 +112,28 @@ class HumanInTheLoopBlock(Block): yield "review_message", "Auto-approved (safe mode disabled)" return - try: - result = await self.get_or_create_human_review( - user_id=user_id, - node_exec_id=node_exec_id, - graph_exec_id=graph_exec_id, - graph_id=graph_id, - graph_version=graph_version, - input_data=input_data.data, - message=input_data.name, - editable=input_data.editable, - ) - except Exception as e: - logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}") - raise + decision = await self.handle_review_decision( + input_data=input_data.data, + user_id=user_id, + node_exec_id=node_exec_id, + graph_exec_id=graph_exec_id, + graph_id=graph_id, + graph_version=graph_version, + execution_context=execution_context, + block_name=self.name, + editable=input_data.editable, + ) - if result is None: - logger.info( - f"HITL block pausing execution for node {node_exec_id} - awaiting human review" - ) - try: - await self.update_node_execution_status( - exec_id=node_exec_id, - status=ExecutionStatus.REVIEW, - ) - return - except Exception as e: - logger.error( - f"Failed to update node status for HITL block {node_exec_id}: {str(e)}" - ) - raise + if decision is None: + return - if not result.processed: - await self.update_review_processed_status( - node_exec_id=node_exec_id, processed=True - ) + status = decision.review_result.status + if status == ReviewStatus.APPROVED: + yield "approved_data", decision.review_result.data + elif status == ReviewStatus.REJECTED: + yield "rejected_data", decision.review_result.data + else: + raise RuntimeError(f"Unexpected review status: {status}") - if result.status == ReviewStatus.APPROVED: - yield "approved_data", result.data - if result.message: - yield "review_message", result.message - - elif result.status == ReviewStatus.REJECTED: - yield "rejected_data", result.data - if result.message: - yield "review_message", result.message + if decision.message: + yield "review_message", decision.message diff --git a/autogpt_platform/backend/backend/blocks/reddit.py b/autogpt_platform/backend/backend/blocks/reddit.py index 231e7affef..1109d568db 100644 --- a/autogpt_platform/backend/backend/blocks/reddit.py +++ b/autogpt_platform/backend/backend/blocks/reddit.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from typing import Iterator, Literal import praw +from praw.models import Comment, MoreComments, Submission from pydantic import BaseModel, SecretStr from backend.data.block import ( @@ -15,33 +16,51 @@ from backend.data.block import ( from backend.data.model import ( CredentialsField, CredentialsMetaInput, + OAuth2Credentials, SchemaField, - UserPasswordCredentials, ) from backend.integrations.providers import ProviderName from backend.util.mock import MockObject from backend.util.settings import Settings -RedditCredentials = UserPasswordCredentials +# Type aliases for Reddit API options +UserPostSort = Literal["new", "hot", "top", "controversial"] +SearchSort = Literal["relevance", "hot", "top", "new", "comments"] +TimeFilter = Literal["all", "day", "hour", "month", "week", "year"] +CommentSort = Literal["best", "top", "new", "controversial", "old", "qa"] +InboxType = Literal["all", "unread", "messages", "mentions", "comment_replies"] + +RedditCredentials = OAuth2Credentials RedditCredentialsInput = CredentialsMetaInput[ Literal[ProviderName.REDDIT], - Literal["user_password"], + Literal["oauth2"], ] def RedditCredentialsField() -> RedditCredentialsInput: """Creates a Reddit credentials input on a block.""" return CredentialsField( - description="The Reddit integration requires a username and password.", + description="Connect your Reddit account to access Reddit features.", ) -TEST_CREDENTIALS = UserPasswordCredentials( +TEST_CREDENTIALS = OAuth2Credentials( id="01234567-89ab-cdef-0123-456789abcdef", provider="reddit", - username=SecretStr("mock-reddit-username"), - password=SecretStr("mock-reddit-password"), + access_token=SecretStr("mock-reddit-access-token"), + refresh_token=SecretStr("mock-reddit-refresh-token"), + access_token_expires_at=9999999999, + scopes=[ + "identity", + "read", + "submit", + "edit", + "history", + "privatemessages", + "flair", + ], title="Mock Reddit credentials", + username="mock-reddit-username", ) TEST_CREDENTIALS_INPUT = { @@ -53,27 +72,29 @@ TEST_CREDENTIALS_INPUT = { class RedditPost(BaseModel): - id: str + post_id: str subreddit: str title: str body: str -class RedditComment(BaseModel): - post_id: str - comment: str - - settings = Settings() logger = logging.getLogger(__name__) def get_praw(creds: RedditCredentials) -> praw.Reddit: + """ + Create a PRAW Reddit client using OAuth2 credentials. + + Uses the refresh_token for authentication, which allows the client + to automatically refresh the access token when needed. + """ client = praw.Reddit( client_id=settings.secrets.reddit_client_id, client_secret=settings.secrets.reddit_client_secret, - username=creds.username.get_secret_value(), - password=creds.password.get_secret_value(), + refresh_token=( + creds.refresh_token.get_secret_value() if creds.refresh_token else None + ), user_agent=settings.config.reddit_user_agent, ) me = client.user.me() @@ -83,11 +104,36 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit: return client +def strip_reddit_prefix(id_str: str) -> str: + """ + Strip Reddit type prefix (t1_, t3_, etc.) from an ID if present. + + Reddit uses type prefixes like t1_ for comments, t3_ for posts, etc. + This helper normalizes IDs by removing these prefixes when present, + allowing blocks to accept both 'abc123' and 't3_abc123' formats. + + Args: + id_str: The ID string that may have a Reddit type prefix. + + Returns: + The ID without the type prefix. + """ + if ( + len(id_str) > 3 + and id_str[0] == "t" + and id_str[1].isdigit() + and id_str[2] == "_" + ): + return id_str[3:] + return id_str + + class GetRedditPostsBlock(Block): class Input(BlockSchemaInput): subreddit: str = SchemaField( description="Subreddit name, excluding the /r/ prefix", default="writingprompts", + advanced=False, ) credentials: RedditCredentialsInput = RedditCredentialsField() last_minutes: int | None = SchemaField( @@ -128,26 +174,32 @@ class GetRedditPostsBlock(Block): ( "post", RedditPost( - id="id1", subreddit="subreddit", title="title1", body="body1" + post_id="id1", + subreddit="subreddit", + title="title1", + body="body1", ), ), ( "post", RedditPost( - id="id2", subreddit="subreddit", title="title2", body="body2" + post_id="id2", + subreddit="subreddit", + title="title2", + body="body2", ), ), ( "posts", [ RedditPost( - id="id1", + post_id="id1", subreddit="subreddit", title="title1", body="body1", ), RedditPost( - id="id2", + post_id="id2", subreddit="subreddit", title="title2", body="body2", @@ -184,13 +236,14 @@ class GetRedditPostsBlock(Block): ) time_difference = current_time - post_datetime if time_difference.total_seconds() / 60 > input_data.last_minutes: - continue + # Posts are ordered newest-first, so all subsequent posts will also be older + break if input_data.last_post and post.id == input_data.last_post: break reddit_post = RedditPost( - id=post.id, + post_id=post.id, subreddit=input_data.subreddit, title=post.title, body=post.selftext, @@ -204,10 +257,18 @@ class GetRedditPostsBlock(Block): class PostRedditCommentBlock(Block): class Input(BlockSchemaInput): credentials: RedditCredentialsInput = RedditCredentialsField() - data: RedditComment = SchemaField(description="Reddit comment") + post_id: str = SchemaField( + description="The ID of the post to comment on", + ) + comment: str = SchemaField( + description="The content of the comment to post", + ) class Output(BlockSchemaOutput): comment_id: str = SchemaField(description="Posted comment ID") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) def __init__(self): super().__init__( @@ -223,17 +284,24 @@ class PostRedditCommentBlock(Block): test_credentials=TEST_CREDENTIALS, test_input={ "credentials": TEST_CREDENTIALS_INPUT, - "data": {"post_id": "id", "comment": "comment"}, + "post_id": "test_post_id", + "comment": "comment", + }, + test_output=[ + ("comment_id", "dummy_comment_id"), + ("post_id", "test_post_id"), + ], + test_mock={ + "reply_post": lambda creds, post_id, comment: "dummy_comment_id" }, - test_output=[("comment_id", "dummy_comment_id")], - test_mock={"reply_post": lambda creds, comment: "dummy_comment_id"}, ) @staticmethod - def reply_post(creds: RedditCredentials, comment: RedditComment) -> str: + def reply_post(creds: RedditCredentials, post_id: str, comment: str) -> str: client = get_praw(creds) - submission = client.submission(id=comment.post_id) - new_comment = submission.reply(comment.comment) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + new_comment = submission.reply(comment) if not new_comment: raise ValueError("Failed to post comment.") return new_comment.id @@ -241,4 +309,2230 @@ class PostRedditCommentBlock(Block): async def run( self, input_data: Input, *, credentials: RedditCredentials, **kwargs ) -> BlockOutput: - yield "comment_id", self.reply_post(credentials, input_data.data) + yield "comment_id", self.reply_post( + credentials, + post_id=input_data.post_id, + comment=input_data.comment, + ) + yield "post_id", input_data.post_id + + +class CreateRedditPostBlock(Block): + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit to post to, excluding the /r/ prefix", + ) + title: str = SchemaField( + description="Title of the post", + ) + content: str = SchemaField( + description="Body text of the post (for text posts)", + default="", + ) + url: str | None = SchemaField( + description="URL to submit (for link posts). If provided, content is ignored.", + default=None, + ) + flair_id: str | None = SchemaField( + description="Flair template ID to apply to the post (from GetSubredditFlairsBlock)", + default=None, + ) + flair_text: str | None = SchemaField( + description="Custom flair text (only used if the flair template allows editing)", + default=None, + ) + + class Output(BlockSchemaOutput): + post_id: str = SchemaField(description="ID of the created post") + post_url: str = SchemaField(description="URL of the created post") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + + def __init__(self): + super().__init__( + id="f3a2b1c0-8d7e-4f6a-9b5c-1234567890ab", + description="Create a new post on a subreddit. Can create text posts or link posts.", + categories={BlockCategory.SOCIAL}, + input_schema=CreateRedditPostBlock.Input, + output_schema=CreateRedditPostBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "test", + "title": "Test Post", + "content": "This is a test post body.", + }, + test_output=[ + ("post_id", "abc123"), + ("post_url", "https://reddit.com/r/test/comments/abc123/test_post/"), + ("subreddit", "test"), + ], + test_mock={ + "create_post": lambda creds, subreddit, title, content, url, flair_id, flair_text: ( + "abc123", + "https://reddit.com/r/test/comments/abc123/test_post/", + ) + }, + ) + + @staticmethod + def create_post( + creds: RedditCredentials, + subreddit: str, + title: str, + content: str = "", + url: str | None = None, + flair_id: str | None = None, + flair_text: str | None = None, + ) -> tuple[str, str]: + """ + Create a new post on a subreddit. + + Args: + creds: Reddit OAuth2 credentials + subreddit: Subreddit name (without /r/ prefix) + title: Post title + content: Post body text (for text posts) + url: URL to submit (for link posts, overrides content) + flair_id: Optional flair template ID to apply + flair_text: Optional custom flair text (for editable flairs) + + Returns: + Tuple of (post_id, post_url) + """ + client = get_praw(creds) + sub = client.subreddit(subreddit) + + if url: + submission = sub.submit( + title=title, url=url, flair_id=flair_id, flair_text=flair_text + ) + else: + submission = sub.submit( + title=title, selftext=content, flair_id=flair_id, flair_text=flair_text + ) + + return submission.id, f"https://reddit.com{submission.permalink}" + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + post_id, post_url = self.create_post( + credentials, + input_data.subreddit, + input_data.title, + input_data.content, + input_data.url, + input_data.flair_id, + input_data.flair_text, + ) + yield "post_id", post_id + yield "post_url", post_url + yield "subreddit", input_data.subreddit + + +class RedditPostDetails(BaseModel): + """Detailed information about a Reddit post.""" + + id: str + subreddit: str + title: str + body: str + author: str + score: int + upvote_ratio: float + num_comments: int + created_utc: float + url: str + permalink: str + is_self: bool + over_18: bool + + +class GetRedditPostBlock(Block): + """Get detailed information about a specific Reddit post.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_id: str = SchemaField( + description="The ID of the post to fetch (e.g., 'abc123' or full ID 't3_abc123')", + ) + + class Output(BlockSchemaOutput): + post: RedditPostDetails = SchemaField(description="Detailed post information") + error: str = SchemaField( + description="Error message if the post couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="36e6a259-168c-4032-83ec-b2935d0e4584", + description="Get detailed information about a specific Reddit post by its ID.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditPostBlock.Input, + output_schema=GetRedditPostBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_id": "abc123", + }, + test_output=[ + ( + "post", + RedditPostDetails( + id="abc123", + subreddit="test", + title="Test Post", + body="Test body", + author="testuser", + score=100, + upvote_ratio=0.95, + num_comments=10, + created_utc=1234567890.0, + url="https://reddit.com/r/test/comments/abc123/test_post/", + permalink="/r/test/comments/abc123/test_post/", + is_self=True, + over_18=False, + ), + ), + ], + test_mock={ + "get_post": lambda creds, post_id: RedditPostDetails( + id="abc123", + subreddit="test", + title="Test Post", + body="Test body", + author="testuser", + score=100, + upvote_ratio=0.95, + num_comments=10, + created_utc=1234567890.0, + url="https://reddit.com/r/test/comments/abc123/test_post/", + permalink="/r/test/comments/abc123/test_post/", + is_self=True, + over_18=False, + ) + }, + ) + + @staticmethod + def get_post(creds: RedditCredentials, post_id: str) -> RedditPostDetails: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + + return RedditPostDetails( + id=submission.id, + subreddit=submission.subreddit.display_name, + title=submission.title, + body=submission.selftext, + author=str(submission.author) if submission.author else "[deleted]", + score=submission.score, + upvote_ratio=submission.upvote_ratio, + num_comments=submission.num_comments, + created_utc=submission.created_utc, + url=submission.url, + permalink=submission.permalink, + is_self=submission.is_self, + over_18=submission.over_18, + ) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + post = self.get_post(credentials, input_data.post_id) + yield "post", post + except Exception as e: + yield "error", str(e) + + +class GetUserPostsBlock(Block): + """Get posts by a specific Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + username: str = SchemaField( + description="Reddit username to fetch posts from (without /u/ prefix)", + ) + post_limit: int = SchemaField( + description="Maximum number of posts to fetch", + default=10, + ) + sort: UserPostSort = SchemaField( + description="Sort order for user posts", + default="new", + ) + + class Output(BlockSchemaOutput): + post: RedditPost = SchemaField(description="A post by the user") + posts: list[RedditPost] = SchemaField(description="All posts by the user") + error: str = SchemaField( + description="Error message if posts couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="6fbe6329-d13e-4d2e-bd4d-b4d921b56161", + description="Fetch posts by a specific Reddit user.", + categories={BlockCategory.SOCIAL}, + input_schema=GetUserPostsBlock.Input, + output_schema=GetUserPostsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "username": "testuser", + "post_limit": 2, + }, + test_output=[ + ( + "post", + RedditPost( + post_id="id1", subreddit="sub1", title="title1", body="body1" + ), + ), + ( + "post", + RedditPost( + post_id="id2", subreddit="sub2", title="title2", body="body2" + ), + ), + ( + "posts", + [ + RedditPost( + post_id="id1", + subreddit="sub1", + title="title1", + body="body1", + ), + RedditPost( + post_id="id2", + subreddit="sub2", + title="title2", + body="body2", + ), + ], + ), + ], + test_mock={ + "get_user_posts": lambda creds, username, limit, sort: [ + MockObject( + id="id1", + subreddit=MockObject(display_name="sub1"), + title="title1", + selftext="body1", + ), + MockObject( + id="id2", + subreddit=MockObject(display_name="sub2"), + title="title2", + selftext="body2", + ), + ] + }, + ) + + @staticmethod + def get_user_posts( + creds: RedditCredentials, username: str, limit: int, sort: UserPostSort + ) -> list[Submission]: + client = get_praw(creds) + redditor = client.redditor(username) + + if sort == "new": + submissions = redditor.submissions.new(limit=limit) + elif sort == "hot": + submissions = redditor.submissions.hot(limit=limit) + elif sort == "top": + submissions = redditor.submissions.top(limit=limit) + elif sort == "controversial": + submissions = redditor.submissions.controversial(limit=limit) + else: + submissions = redditor.submissions.new(limit=limit) + + return list(submissions) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + submissions = self.get_user_posts( + credentials, + input_data.username, + input_data.post_limit, + input_data.sort, + ) + all_posts = [] + for submission in submissions: + post = RedditPost( + post_id=submission.id, + subreddit=submission.subreddit.display_name, + title=submission.title, + body=submission.selftext, + ) + all_posts.append(post) + yield "post", post + yield "posts", all_posts + except Exception as e: + yield "error", str(e) + + +class RedditGetMyPostsBlock(Block): + """Get posts by the authenticated Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_limit: int = SchemaField( + description="Maximum number of posts to fetch", + default=10, + ) + sort: UserPostSort = SchemaField( + description="Sort order for posts", + default="new", + ) + + class Output(BlockSchemaOutput): + post: RedditPost = SchemaField(description="A post by you") + posts: list[RedditPost] = SchemaField(description="All your posts") + error: str = SchemaField( + description="Error message if posts couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="4ab3381b-0c07-4201-89b3-fa2ec264f154", + description="Fetch posts created by the authenticated Reddit user (you).", + categories={BlockCategory.SOCIAL}, + input_schema=RedditGetMyPostsBlock.Input, + output_schema=RedditGetMyPostsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_limit": 2, + }, + test_output=[ + ( + "post", + RedditPost( + post_id="id1", subreddit="sub1", title="title1", body="body1" + ), + ), + ( + "post", + RedditPost( + post_id="id2", subreddit="sub2", title="title2", body="body2" + ), + ), + ( + "posts", + [ + RedditPost( + post_id="id1", + subreddit="sub1", + title="title1", + body="body1", + ), + RedditPost( + post_id="id2", + subreddit="sub2", + title="title2", + body="body2", + ), + ], + ), + ], + test_mock={ + "get_my_posts": lambda creds, limit, sort: [ + MockObject( + id="id1", + subreddit=MockObject(display_name="sub1"), + title="title1", + selftext="body1", + ), + MockObject( + id="id2", + subreddit=MockObject(display_name="sub2"), + title="title2", + selftext="body2", + ), + ] + }, + ) + + @staticmethod + def get_my_posts( + creds: RedditCredentials, limit: int, sort: UserPostSort + ) -> list[Submission]: + client = get_praw(creds) + me = client.user.me() + if not me: + raise ValueError("Could not get authenticated user.") + + if sort == "new": + submissions = me.submissions.new(limit=limit) + elif sort == "hot": + submissions = me.submissions.hot(limit=limit) + elif sort == "top": + submissions = me.submissions.top(limit=limit) + elif sort == "controversial": + submissions = me.submissions.controversial(limit=limit) + else: + submissions = me.submissions.new(limit=limit) + + return list(submissions) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + submissions = self.get_my_posts( + credentials, + input_data.post_limit, + input_data.sort, + ) + all_posts = [] + for submission in submissions: + post = RedditPost( + post_id=submission.id, + subreddit=submission.subreddit.display_name, + title=submission.title, + body=submission.selftext, + ) + all_posts.append(post) + yield "post", post + yield "posts", all_posts + except Exception as e: + yield "error", str(e) + + +class RedditSearchResult(BaseModel): + """A search result from Reddit.""" + + id: str + subreddit: str + title: str + body: str + author: str + score: int + num_comments: int + created_utc: float + permalink: str + + +class SearchRedditBlock(Block): + """Search Reddit for posts matching a query.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + query: str = SchemaField( + description="Search query string", + ) + subreddit: str | None = SchemaField( + description="Limit search to a specific subreddit (without /r/ prefix)", + default=None, + ) + sort: SearchSort = SchemaField( + description="Sort order for search results", + default="relevance", + ) + time_filter: TimeFilter = SchemaField( + description="Time filter for search results", + default="all", + ) + limit: int = SchemaField( + description="Maximum number of results to return", + default=10, + ) + + class Output(BlockSchemaOutput): + result: RedditSearchResult = SchemaField(description="A search result") + results: list[RedditSearchResult] = SchemaField( + description="All search results" + ) + error: str = SchemaField(description="Error message if search failed") + + def __init__(self): + super().__init__( + id="4a0c975e-807b-4d5e-83c9-1619864a4b1a", + description="Search Reddit for posts matching a query. Can search all of Reddit or a specific subreddit.", + categories={BlockCategory.SOCIAL}, + input_schema=SearchRedditBlock.Input, + output_schema=SearchRedditBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "query": "test query", + "limit": 2, + }, + test_output=[ + ( + "result", + RedditSearchResult( + id="id1", + subreddit="sub1", + title="title1", + body="body1", + author="author1", + score=100, + num_comments=10, + created_utc=1234567890.0, + permalink="/r/sub1/comments/id1/title1/", + ), + ), + ( + "result", + RedditSearchResult( + id="id2", + subreddit="sub2", + title="title2", + body="body2", + author="author2", + score=50, + num_comments=5, + created_utc=1234567891.0, + permalink="/r/sub2/comments/id2/title2/", + ), + ), + ( + "results", + [ + RedditSearchResult( + id="id1", + subreddit="sub1", + title="title1", + body="body1", + author="author1", + score=100, + num_comments=10, + created_utc=1234567890.0, + permalink="/r/sub1/comments/id1/title1/", + ), + RedditSearchResult( + id="id2", + subreddit="sub2", + title="title2", + body="body2", + author="author2", + score=50, + num_comments=5, + created_utc=1234567891.0, + permalink="/r/sub2/comments/id2/title2/", + ), + ], + ), + ], + test_mock={ + "search_reddit": lambda creds, query, subreddit, sort, time_filter, limit: [ + MockObject( + id="id1", + subreddit=MockObject(display_name="sub1"), + title="title1", + selftext="body1", + author="author1", + score=100, + num_comments=10, + created_utc=1234567890.0, + permalink="/r/sub1/comments/id1/title1/", + ), + MockObject( + id="id2", + subreddit=MockObject(display_name="sub2"), + title="title2", + selftext="body2", + author="author2", + score=50, + num_comments=5, + created_utc=1234567891.0, + permalink="/r/sub2/comments/id2/title2/", + ), + ] + }, + ) + + @staticmethod + def search_reddit( + creds: RedditCredentials, + query: str, + subreddit: str | None, + sort: SearchSort, + time_filter: TimeFilter, + limit: int, + ) -> list[Submission]: + client = get_praw(creds) + + if subreddit: + sub = client.subreddit(subreddit) + results = sub.search(query, sort=sort, time_filter=time_filter, limit=limit) + else: + results = client.subreddit("all").search( + query, sort=sort, time_filter=time_filter, limit=limit + ) + + return list(results) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + submissions = self.search_reddit( + credentials, + input_data.query, + input_data.subreddit, + input_data.sort, + input_data.time_filter, + input_data.limit, + ) + all_results = [] + for submission in submissions: + result = RedditSearchResult( + id=submission.id, + subreddit=submission.subreddit.display_name, + title=submission.title, + body=submission.selftext, + author=str(submission.author) if submission.author else "[deleted]", + score=submission.score, + num_comments=submission.num_comments, + created_utc=submission.created_utc, + permalink=submission.permalink, + ) + all_results.append(result) + yield "result", result + yield "results", all_results + except Exception as e: + yield "error", str(e) + + +class EditRedditPostBlock(Block): + """Edit an existing Reddit post that you own.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_id: str = SchemaField( + description="The ID of the post to edit (must be your own post)", + ) + new_content: str = SchemaField( + description="The new body text for the post", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the edit was successful") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + post_url: str = SchemaField(description="URL of the edited post") + error: str = SchemaField(description="Error message if the edit failed") + + def __init__(self): + super().__init__( + id="cdb9df0f-8b1d-433e-873a-ededc1b6479d", + description="Edit the body text of an existing Reddit post that you own. Only works for self/text posts.", + categories={BlockCategory.SOCIAL}, + input_schema=EditRedditPostBlock.Input, + output_schema=EditRedditPostBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_id": "abc123", + "new_content": "Updated post content", + }, + test_output=[ + ("success", True), + ("post_id", "abc123"), + ("post_url", "https://reddit.com/r/test/comments/abc123/test_post/"), + ], + test_mock={ + "edit_post": lambda creds, post_id, new_content: ( + True, + "https://reddit.com/r/test/comments/abc123/test_post/", + ) + }, + ) + + @staticmethod + def edit_post( + creds: RedditCredentials, post_id: str, new_content: str + ) -> tuple[bool, str]: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + submission.edit(new_content) + return True, f"https://reddit.com{submission.permalink}" + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success, post_url = self.edit_post( + credentials, input_data.post_id, input_data.new_content + ) + yield "success", success + yield "post_id", input_data.post_id + yield "post_url", post_url + except Exception as e: + error_msg = str(e) + if "403" in error_msg: + error_msg = ( + "Permission denied (403): You can only edit your own posts. " + "Make sure the post belongs to the authenticated Reddit account." + ) + yield "error", error_msg + + +class SubredditInfo(BaseModel): + """Information about a subreddit.""" + + name: str + display_name: str + title: str + description: str + public_description: str + subscribers: int + created_utc: float + over_18: bool + url: str + + +class GetSubredditInfoBlock(Block): + """Get information about a subreddit.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit name (without /r/ prefix)", + ) + + class Output(BlockSchemaOutput): + info: SubredditInfo = SchemaField(description="Subreddit information") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + error: str = SchemaField( + description="Error message if the subreddit couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="5a2d1f0c-01fb-43ea-bad7-2260d269c930", + description="Get information about a subreddit including subscriber count, description, and rules.", + categories={BlockCategory.SOCIAL}, + input_schema=GetSubredditInfoBlock.Input, + output_schema=GetSubredditInfoBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "python", + }, + test_output=[ + ( + "info", + SubredditInfo( + name="t5_2qh0y", + display_name="python", + title="Python", + description="News about the Python programming language", + public_description="News about Python", + subscribers=1000000, + created_utc=1234567890.0, + over_18=False, + url="/r/python/", + ), + ), + ("subreddit", "python"), + ], + test_mock={ + "get_subreddit_info": lambda creds, subreddit: SubredditInfo( + name="t5_2qh0y", + display_name="python", + title="Python", + description="News about the Python programming language", + public_description="News about Python", + subscribers=1000000, + created_utc=1234567890.0, + over_18=False, + url="/r/python/", + ) + }, + ) + + @staticmethod + def get_subreddit_info(creds: RedditCredentials, subreddit: str) -> SubredditInfo: + client = get_praw(creds) + sub = client.subreddit(subreddit) + + return SubredditInfo( + name=sub.name, + display_name=sub.display_name, + title=sub.title, + description=sub.description, + public_description=sub.public_description, + subscribers=sub.subscribers, + created_utc=sub.created_utc, + over_18=sub.over18, + url=sub.url, + ) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + info = self.get_subreddit_info(credentials, input_data.subreddit) + yield "info", info + yield "subreddit", input_data.subreddit + except Exception as e: + yield "error", str(e) + + +class RedditComment(BaseModel): + """A Reddit comment.""" + + comment_id: str + post_id: str + parent_comment_id: str | None + author: str + body: str + score: int + created_utc: float + edited: bool + is_submitter: bool + permalink: str + depth: int + + +class GetRedditPostCommentsBlock(Block): + """Get comments on a Reddit post.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_id: str = SchemaField( + description="The ID of the post to get comments from", + ) + limit: int = SchemaField( + description="Maximum number of top-level comments to fetch (max 100)", + default=25, + ) + sort: CommentSort = SchemaField( + description="Sort order for comments", + default="best", + ) + + class Output(BlockSchemaOutput): + comment: RedditComment = SchemaField(description="A comment on the post") + comments: list[RedditComment] = SchemaField(description="All fetched comments") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + error: str = SchemaField( + description="Error message if comments couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="98422b2c-c3b0-4d70-871f-56bd966f46da", + description="Get top-level comments on a Reddit post.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditPostCommentsBlock.Input, + output_schema=GetRedditPostCommentsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_id": "abc123", + "limit": 2, + }, + test_output=[ + ( + "comment", + RedditComment( + comment_id="comment1", + post_id="abc123", + parent_comment_id=None, + author="user1", + body="Comment body 1", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ), + ), + ( + "comment", + RedditComment( + comment_id="comment2", + post_id="abc123", + parent_comment_id=None, + author="user2", + body="Comment body 2", + score=5, + created_utc=1234567891.0, + edited=False, + is_submitter=True, + permalink="/r/test/comments/abc123/test/comment2/", + depth=0, + ), + ), + ( + "comments", + [ + RedditComment( + comment_id="comment1", + post_id="abc123", + parent_comment_id=None, + author="user1", + body="Comment body 1", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ), + RedditComment( + comment_id="comment2", + post_id="abc123", + parent_comment_id=None, + author="user2", + body="Comment body 2", + score=5, + created_utc=1234567891.0, + edited=False, + is_submitter=True, + permalink="/r/test/comments/abc123/test/comment2/", + depth=0, + ), + ], + ), + ("post_id", "abc123"), + ], + test_mock={ + "get_comments": lambda creds, post_id, limit, sort: [ + MockObject( + id="comment1", + link_id="t3_abc123", + parent_id="t3_abc123", + author="user1", + body="Comment body 1", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ), + MockObject( + id="comment2", + link_id="t3_abc123", + parent_id="t3_abc123", + author="user2", + body="Comment body 2", + score=5, + created_utc=1234567891.0, + edited=False, + is_submitter=True, + permalink="/r/test/comments/abc123/test/comment2/", + depth=0, + ), + ] + }, + ) + + @staticmethod + def get_comments( + creds: RedditCredentials, post_id: str, limit: int, sort: CommentSort + ) -> list[Comment]: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + submission.comment_sort = sort + # Replace MoreComments with actual comments up to limit + submission.comments.replace_more(limit=0) + # Return only top-level comments (depth=0), limited + # CommentForest supports indexing, so use slicing directly + max_comments = min(limit, 100) + return [ + submission.comments[i] + for i in range(min(len(submission.comments), max_comments)) + ] + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + comments = self.get_comments( + credentials, + input_data.post_id, + input_data.limit, + input_data.sort, + ) + all_comments = [] + for comment in comments: + # Extract post_id from link_id (format: t3_xxxxx) + comment_post_id = strip_reddit_prefix(comment.link_id) + + # parent_comment_id is None for top-level comments (parent is a post: t3_) + # For replies, extract the comment ID from t1_xxxxx + parent_comment_id = None + if comment.parent_id.startswith("t1_"): + parent_comment_id = strip_reddit_prefix(comment.parent_id) + + comment_data = RedditComment( + comment_id=comment.id, + post_id=comment_post_id, + parent_comment_id=parent_comment_id, + author=str(comment.author) if comment.author else "[deleted]", + body=comment.body, + score=comment.score, + created_utc=comment.created_utc, + edited=bool(comment.edited), + is_submitter=comment.is_submitter, + permalink=comment.permalink, + depth=comment.depth, + ) + all_comments.append(comment_data) + yield "comment", comment_data + yield "comments", all_comments + yield "post_id", input_data.post_id + except Exception as e: + yield "error", str(e) + + +class GetRedditCommentRepliesBlock(Block): + """Get replies to a specific Reddit comment.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to get replies from", + ) + post_id: str = SchemaField( + description="The ID of the post containing the comment", + ) + limit: int = SchemaField( + description="Maximum number of replies to fetch (max 50)", + default=10, + ) + + class Output(BlockSchemaOutput): + reply: RedditComment = SchemaField(description="A reply to the comment") + replies: list[RedditComment] = SchemaField(description="All replies") + comment_id: str = SchemaField( + description="The parent comment ID (pass-through for chaining)" + ) + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + error: str = SchemaField( + description="Error message if replies couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="7fa83965-7289-432f-98a9-1575f5bcc8f1", + description="Get replies to a specific Reddit comment.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditCommentRepliesBlock.Input, + output_schema=GetRedditCommentRepliesBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "comment1", + "post_id": "abc123", + "limit": 2, + }, + test_output=[ + ( + "reply", + RedditComment( + comment_id="reply1", + post_id="abc123", + parent_comment_id="comment1", + author="replier1", + body="Reply body 1", + score=3, + created_utc=1234567892.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/reply1/", + depth=1, + ), + ), + ( + "replies", + [ + RedditComment( + comment_id="reply1", + post_id="abc123", + parent_comment_id="comment1", + author="replier1", + body="Reply body 1", + score=3, + created_utc=1234567892.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/reply1/", + depth=1, + ), + ], + ), + ("comment_id", "comment1"), + ("post_id", "abc123"), + ], + test_mock={ + "get_replies": lambda creds, comment_id, post_id, limit: [ + MockObject( + id="reply1", + link_id="t3_abc123", + parent_id="t1_comment1", + author="replier1", + body="Reply body 1", + score=3, + created_utc=1234567892.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/reply1/", + depth=1, + ), + ] + }, + ) + + @staticmethod + def get_replies( + creds: RedditCredentials, comment_id: str, post_id: str, limit: int + ) -> list[Comment]: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + comment_id = strip_reddit_prefix(comment_id) + + # Get the submission and find the comment + submission = client.submission(id=post_id) + submission.comments.replace_more(limit=0) + + # Find the target comment - filter out MoreComments which don't have .id + comment = None + for c in submission.comments.list(): + if isinstance(c, MoreComments): + continue + if c.id == comment_id: + comment = c + break + + if not comment: + return [] + + # Get direct replies - filter out MoreComments objects + replies = [] + # CommentForest supports indexing + for i in range(len(comment.replies)): + reply = comment.replies[i] + if isinstance(reply, MoreComments): + continue + replies.append(reply) + if len(replies) >= min(limit, 50): + break + + return replies + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + replies = self.get_replies( + credentials, + input_data.comment_id, + input_data.post_id, + input_data.limit, + ) + all_replies = [] + for reply in replies: + reply_post_id = strip_reddit_prefix(reply.link_id) + + # parent_comment_id is the parent comment (always present for replies) + parent_comment_id = None + if reply.parent_id.startswith("t1_"): + parent_comment_id = strip_reddit_prefix(reply.parent_id) + + reply_data = RedditComment( + comment_id=reply.id, + post_id=reply_post_id, + parent_comment_id=parent_comment_id, + author=str(reply.author) if reply.author else "[deleted]", + body=reply.body, + score=reply.score, + created_utc=reply.created_utc, + edited=bool(reply.edited), + is_submitter=reply.is_submitter, + permalink=reply.permalink, + depth=reply.depth, + ) + all_replies.append(reply_data) + yield "reply", reply_data + yield "replies", all_replies + yield "comment_id", input_data.comment_id + yield "post_id", input_data.post_id + except Exception as e: + yield "error", str(e) + + +class GetRedditCommentBlock(Block): + """Get details about a specific Reddit comment.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to fetch", + ) + + class Output(BlockSchemaOutput): + comment: RedditComment = SchemaField(description="The comment details") + error: str = SchemaField( + description="Error message if comment couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="72cb311a-5998-4e0a-9bc4-f1b67a97284e", + description="Get details about a specific Reddit comment by its ID.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditCommentBlock.Input, + output_schema=GetRedditCommentBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "comment1", + }, + test_output=[ + ( + "comment", + RedditComment( + comment_id="comment1", + post_id="abc123", + parent_comment_id=None, + author="user1", + body="Comment body", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ), + ), + ], + test_mock={ + "get_comment": lambda creds, comment_id: MockObject( + id="comment1", + link_id="t3_abc123", + parent_id="t3_abc123", + author="user1", + body="Comment body", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ) + }, + ) + + @staticmethod + def get_comment(creds: RedditCredentials, comment_id: str): + client = get_praw(creds) + comment_id = strip_reddit_prefix(comment_id) + return client.comment(id=comment_id) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + comment = self.get_comment(credentials, input_data.comment_id) + + post_id = strip_reddit_prefix(comment.link_id) + + # parent_comment_id is None for top-level comments (parent is a post: t3_) + parent_comment_id = None + if comment.parent_id.startswith("t1_"): + parent_comment_id = strip_reddit_prefix(comment.parent_id) + + comment_data = RedditComment( + comment_id=comment.id, + post_id=post_id, + parent_comment_id=parent_comment_id, + author=str(comment.author) if comment.author else "[deleted]", + body=comment.body, + score=comment.score, + created_utc=comment.created_utc, + edited=bool(comment.edited), + is_submitter=comment.is_submitter, + permalink=comment.permalink, + # depth is only available when comments are fetched as part of a tree, + # not when fetched directly by ID + depth=getattr(comment, "depth", 0), + ) + yield "comment", comment_data + except Exception as e: + yield "error", str(e) + + +class ReplyToRedditCommentBlock(Block): + """Reply to a specific Reddit comment.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to reply to", + ) + reply_text: str = SchemaField( + description="The text content of the reply", + ) + + class Output(BlockSchemaOutput): + comment_id: str = SchemaField(description="ID of the newly created reply") + parent_comment_id: str = SchemaField( + description="The parent comment ID (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if reply failed") + + def __init__(self): + super().__init__( + id="7635b059-3a9f-4f7d-b499-1b56c4f76f4f", + description="Reply to a specific Reddit comment. Useful for threaded conversations.", + categories={BlockCategory.SOCIAL}, + input_schema=ReplyToRedditCommentBlock.Input, + output_schema=ReplyToRedditCommentBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "parent_comment", + "reply_text": "This is a reply", + }, + test_output=[ + ("comment_id", "new_reply_id"), + ("parent_comment_id", "parent_comment"), + ], + test_mock={ + "reply_to_comment": lambda creds, comment_id, reply_text: "new_reply_id" + }, + ) + + @staticmethod + def reply_to_comment( + creds: RedditCredentials, comment_id: str, reply_text: str + ) -> str: + client = get_praw(creds) + comment_id = strip_reddit_prefix(comment_id) + comment = client.comment(id=comment_id) + reply = comment.reply(reply_text) + if not reply: + raise ValueError("Failed to post reply.") + return reply.id + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + new_comment_id = self.reply_to_comment( + credentials, input_data.comment_id, input_data.reply_text + ) + yield "comment_id", new_comment_id + yield "parent_comment_id", input_data.comment_id + except Exception as e: + yield "error", str(e) + + +class RedditUserProfileSubreddit(BaseModel): + """Information about a user's profile subreddit.""" + + name: str + title: str + public_description: str + subscribers: int + over_18: bool + + +class RedditUserInfo(BaseModel): + """Information about a Reddit user.""" + + username: str + user_id: str + comment_karma: int + link_karma: int + total_karma: int + created_utc: float + is_gold: bool + is_mod: bool + has_verified_email: bool + moderated_subreddits: list[str] + profile_subreddit: RedditUserProfileSubreddit | None + + +class GetRedditUserInfoBlock(Block): + """Get information about a Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + username: str = SchemaField( + description="The Reddit username to look up (without /u/ prefix)", + ) + + class Output(BlockSchemaOutput): + user: RedditUserInfo = SchemaField(description="User information") + username: str = SchemaField( + description="The username (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if user lookup failed") + + def __init__(self): + super().__init__( + id="1b4c6bd1-4f28-4bad-9ae9-e7034a0f61ff", + description="Get information about a Reddit user including karma, account age, and verification status.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditUserInfoBlock.Input, + output_schema=GetRedditUserInfoBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "username": "testuser", + }, + test_output=[ + ( + "user", + RedditUserInfo( + username="testuser", + user_id="abc123", + comment_karma=1000, + link_karma=500, + total_karma=1500, + created_utc=1234567890.0, + is_gold=False, + is_mod=True, + has_verified_email=True, + moderated_subreddits=["python", "learnpython"], + profile_subreddit=RedditUserProfileSubreddit( + name="u_testuser", + title="testuser's profile", + public_description="A test user", + subscribers=100, + over_18=False, + ), + ), + ), + ("username", "testuser"), + ], + test_mock={ + "get_user_info": lambda creds, username: MockObject( + name="testuser", + id="abc123", + comment_karma=1000, + link_karma=500, + total_karma=1500, + created_utc=1234567890.0, + is_gold=False, + is_mod=True, + has_verified_email=True, + subreddit=MockObject( + display_name="u_testuser", + title="testuser's profile", + public_description="A test user", + subscribers=100, + over18=False, + ), + ), + "get_moderated_subreddits": lambda creds, username: [ + MockObject(display_name="python"), + MockObject(display_name="learnpython"), + ], + }, + ) + + @staticmethod + def get_user_info(creds: RedditCredentials, username: str): + client = get_praw(creds) + if username.startswith("u/"): + username = username[2:] + return client.redditor(username) + + @staticmethod + def get_moderated_subreddits(creds: RedditCredentials, username: str) -> list: + client = get_praw(creds) + if username.startswith("u/"): + username = username[2:] + redditor = client.redditor(username) + return list(redditor.moderated()) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + redditor = self.get_user_info(credentials, input_data.username) + moderated = self.get_moderated_subreddits(credentials, input_data.username) + + # Extract moderated subreddit names + moderated_subreddits = [sub.display_name for sub in moderated] + + # Get profile subreddit info if available + profile_subreddit = None + if hasattr(redditor, "subreddit") and redditor.subreddit: + try: + profile_subreddit = RedditUserProfileSubreddit( + name=redditor.subreddit.display_name, + title=redditor.subreddit.title or "", + public_description=redditor.subreddit.public_description or "", + subscribers=redditor.subreddit.subscribers or 0, + over_18=( + redditor.subreddit.over18 + if hasattr(redditor.subreddit, "over18") + else False + ), + ) + except Exception: + # Profile subreddit may not be accessible + pass + + user_info = RedditUserInfo( + username=redditor.name, + user_id=redditor.id, + comment_karma=redditor.comment_karma, + link_karma=redditor.link_karma, + total_karma=redditor.total_karma, + created_utc=redditor.created_utc, + is_gold=redditor.is_gold, + is_mod=redditor.is_mod, + has_verified_email=redditor.has_verified_email, + moderated_subreddits=moderated_subreddits, + profile_subreddit=profile_subreddit, + ) + yield "user", user_info + yield "username", input_data.username + except Exception as e: + yield "error", str(e) + + +class SendRedditMessageBlock(Block): + """Send a private message to a Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + username: str = SchemaField( + description="The Reddit username to send a message to (without /u/ prefix)", + ) + subject: str = SchemaField( + description="The subject line of the message", + ) + message: str = SchemaField( + description="The body content of the message", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the message was sent") + username: str = SchemaField( + description="The username (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if sending failed") + + def __init__(self): + super().__init__( + id="7921101a-0537-4259-82ea-bc186ca6b1b6", + description="Send a private message (DM) to a Reddit user.", + categories={BlockCategory.SOCIAL}, + input_schema=SendRedditMessageBlock.Input, + output_schema=SendRedditMessageBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "username": "testuser", + "subject": "Hello", + "message": "This is a test message", + }, + test_output=[ + ("success", True), + ("username", "testuser"), + ], + test_mock={"send_message": lambda creds, username, subject, message: True}, + ) + + @staticmethod + def send_message( + creds: RedditCredentials, username: str, subject: str, message: str + ) -> bool: + client = get_praw(creds) + if username.startswith("u/"): + username = username[2:] + redditor = client.redditor(username) + redditor.message(subject=subject, message=message) + return True + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success = self.send_message( + credentials, + input_data.username, + input_data.subject, + input_data.message, + ) + yield "success", success + yield "username", input_data.username + except Exception as e: + yield "error", str(e) + + +class RedditInboxItem(BaseModel): + """A Reddit inbox item (message, comment reply, or mention).""" + + item_id: str + item_type: str # "message", "comment_reply", "mention" + subject: str + body: str + author: str + created_utc: float + is_read: bool + context: str | None # permalink for comments, None for messages + + +class GetRedditInboxBlock(Block): + """Get messages and notifications from Reddit inbox.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + inbox_type: InboxType = SchemaField( + description="Type of inbox items to fetch", + default="unread", + ) + limit: int = SchemaField( + description="Maximum number of items to fetch", + default=25, + ) + mark_read: bool = SchemaField( + description="Whether to mark fetched items as read", + default=False, + ) + + class Output(BlockSchemaOutput): + item: RedditInboxItem = SchemaField(description="An inbox item") + items: list[RedditInboxItem] = SchemaField(description="All fetched items") + error: str = SchemaField(description="Error message if fetch failed") + + def __init__(self): + super().__init__( + id="5a91bb34-7ffe-4b9e-957b-9d4f8fe8dbc9", + description="Get messages, mentions, and comment replies from your Reddit inbox.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditInboxBlock.Input, + output_schema=GetRedditInboxBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "inbox_type": "unread", + "limit": 10, + }, + test_output=[ + ( + "item", + RedditInboxItem( + item_id="msg123", + item_type="message", + subject="Hello", + body="Test message body", + author="sender_user", + created_utc=1234567890.0, + is_read=False, + context=None, + ), + ), + ( + "items", + [ + RedditInboxItem( + item_id="msg123", + item_type="message", + subject="Hello", + body="Test message body", + author="sender_user", + created_utc=1234567890.0, + is_read=False, + context=None, + ), + ], + ), + ], + test_mock={ + "get_inbox": lambda creds, inbox_type, limit: [ + MockObject( + id="msg123", + subject="Hello", + body="Test message body", + author="sender_user", + created_utc=1234567890.0, + new=True, + context=None, + was_comment=False, + ), + ] + }, + ) + + @staticmethod + def get_inbox(creds: RedditCredentials, inbox_type: InboxType, limit: int) -> list: + client = get_praw(creds) + inbox = client.inbox + + if inbox_type == "all": + items = inbox.all(limit=limit) + elif inbox_type == "unread": + items = inbox.unread(limit=limit) + elif inbox_type == "messages": + items = inbox.messages(limit=limit) + elif inbox_type == "mentions": + items = inbox.mentions(limit=limit) + elif inbox_type == "comment_replies": + items = inbox.comment_replies(limit=limit) + else: + items = inbox.unread(limit=limit) + + return list(items) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + raw_items = self.get_inbox( + credentials, input_data.inbox_type, input_data.limit + ) + all_items = [] + + for item in raw_items: + # Determine item type + if hasattr(item, "was_comment") and item.was_comment: + if hasattr(item, "subject") and "mention" in item.subject.lower(): + item_type = "mention" + else: + item_type = "comment_reply" + else: + item_type = "message" + + inbox_item = RedditInboxItem( + item_id=item.id, + item_type=item_type, + subject=item.subject if hasattr(item, "subject") else "", + body=item.body, + author=str(item.author) if item.author else "[deleted]", + created_utc=item.created_utc, + is_read=not item.new, + context=item.context if hasattr(item, "context") else None, + ) + all_items.append(inbox_item) + yield "item", inbox_item + + # Mark as read if requested + if input_data.mark_read and raw_items: + client = get_praw(credentials) + client.inbox.mark_read(raw_items) + + yield "items", all_items + except Exception as e: + yield "error", str(e) + + +class DeleteRedditPostBlock(Block): + """Delete a Reddit post that you own.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_id: str = SchemaField( + description="The ID of the post to delete (must be your own post)", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the deletion was successful") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if deletion failed") + + def __init__(self): + super().__init__( + id="72e4730a-d66d-4785-8e54-5ab3af450c81", + description="Delete a Reddit post that you own.", + categories={BlockCategory.SOCIAL}, + input_schema=DeleteRedditPostBlock.Input, + output_schema=DeleteRedditPostBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_id": "abc123", + }, + test_output=[ + ("success", True), + ("post_id", "abc123"), + ], + test_mock={"delete_post": lambda creds, post_id: True}, + ) + + @staticmethod + def delete_post(creds: RedditCredentials, post_id: str) -> bool: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + submission.delete() + return True + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success = self.delete_post(credentials, input_data.post_id) + yield "success", success + yield "post_id", input_data.post_id + except Exception as e: + yield "error", str(e) + + +class DeleteRedditCommentBlock(Block): + """Delete a Reddit comment that you own.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to delete (must be your own comment)", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the deletion was successful") + comment_id: str = SchemaField( + description="The comment ID (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if deletion failed") + + def __init__(self): + super().__init__( + id="2650584d-434f-46db-81ef-26c8d8d41f81", + description="Delete a Reddit comment that you own.", + categories={BlockCategory.SOCIAL}, + input_schema=DeleteRedditCommentBlock.Input, + output_schema=DeleteRedditCommentBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "xyz789", + }, + test_output=[ + ("success", True), + ("comment_id", "xyz789"), + ], + test_mock={"delete_comment": lambda creds, comment_id: True}, + ) + + @staticmethod + def delete_comment(creds: RedditCredentials, comment_id: str) -> bool: + client = get_praw(creds) + comment_id = strip_reddit_prefix(comment_id) + comment = client.comment(id=comment_id) + comment.delete() + return True + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success = self.delete_comment(credentials, input_data.comment_id) + yield "success", success + yield "comment_id", input_data.comment_id + except Exception as e: + yield "error", str(e) + + +class SubredditFlair(BaseModel): + """A subreddit link flair template.""" + + flair_id: str + text: str + text_editable: bool + css_class: str = "" # The CSS class for styling (from flair_css_class) + + +class GetSubredditFlairsBlock(Block): + """Get available link flairs for a subreddit.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit name (without /r/ prefix)", + ) + + class Output(BlockSchemaOutput): + flair: SubredditFlair = SchemaField(description="A flair option") + flairs: list[SubredditFlair] = SchemaField(description="All available flairs") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if fetch failed") + + def __init__(self): + super().__init__( + id="ada08f34-a7a9-44aa-869f-0638fa4e0a84", + description="Get available link flair options for a subreddit.", + categories={BlockCategory.SOCIAL}, + input_schema=GetSubredditFlairsBlock.Input, + output_schema=GetSubredditFlairsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "test", + }, + test_output=[ + ( + "flair", + SubredditFlair( + flair_id="abc123", + text="Discussion", + text_editable=False, + css_class="discussion", + ), + ), + ( + "flairs", + [ + SubredditFlair( + flair_id="abc123", + text="Discussion", + text_editable=False, + css_class="discussion", + ), + ], + ), + ("subreddit", "test"), + ], + test_mock={ + "get_flairs": lambda creds, subreddit: [ + { + "flair_template_id": "abc123", + "flair_text": "Discussion", + "flair_text_editable": False, + "flair_css_class": "discussion", + }, + ] + }, + ) + + @staticmethod + def get_flairs(creds: RedditCredentials, subreddit: str) -> list: + client = get_praw(creds) + # Use /r/{subreddit}/api/flairselector endpoint directly with is_newlink=True + # This returns link flairs available for new submissions without requiring mod access + # The link_templates API is moderator-only, so we use flairselector instead + # Path must include the subreddit prefix per Reddit API docs + response = client.post( + f"r/{subreddit}/api/flairselector", + data={"is_newlink": "true"}, + ) + # Response contains 'choices' list with available flairs + choices = response.get("choices", []) + return choices + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + raw_flairs = self.get_flairs(credentials, input_data.subreddit) + all_flairs = [] + + for flair in raw_flairs: + # /api/flairselector returns flairs with flair_template_id, flair_text, etc. + flair_data = SubredditFlair( + flair_id=flair.get("flair_template_id", ""), + text=flair.get("flair_text", ""), + text_editable=flair.get("flair_text_editable", False), + css_class=flair.get("flair_css_class", ""), + ) + all_flairs.append(flair_data) + yield "flair", flair_data + + yield "flairs", all_flairs + yield "subreddit", input_data.subreddit + except Exception as e: + yield "error", str(e) + + +class SubredditRule(BaseModel): + """A subreddit rule.""" + + short_name: str + description: str + kind: str # "all", "link", "comment" + violation_reason: str + priority: int + + +class GetSubredditRulesBlock(Block): + """Get the rules for a subreddit.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit name (without /r/ prefix)", + ) + + class Output(BlockSchemaOutput): + rule: SubredditRule = SchemaField(description="A subreddit rule") + rules: list[SubredditRule] = SchemaField(description="All subreddit rules") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if fetch failed") + + def __init__(self): + super().__init__( + id="222aa36c-fa70-4879-8e8a-37d100175f5c", + description="Get the rules for a subreddit to ensure compliance before posting.", + categories={BlockCategory.SOCIAL}, + input_schema=GetSubredditRulesBlock.Input, + output_schema=GetSubredditRulesBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "test", + }, + test_output=[ + ( + "rule", + SubredditRule( + short_name="No spam", + description="Do not post spam or self-promotional content.", + kind="all", + violation_reason="Spam", + priority=0, + ), + ), + ( + "rules", + [ + SubredditRule( + short_name="No spam", + description="Do not post spam or self-promotional content.", + kind="all", + violation_reason="Spam", + priority=0, + ), + ], + ), + ("subreddit", "test"), + ], + test_mock={ + "get_rules": lambda creds, subreddit: [ + MockObject( + short_name="No spam", + description="Do not post spam or self-promotional content.", + kind="all", + violation_reason="Spam", + priority=0, + ), + ] + }, + ) + + @staticmethod + def get_rules(creds: RedditCredentials, subreddit: str) -> list: + client = get_praw(creds) + sub = client.subreddit(subreddit) + return list(sub.rules) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + raw_rules = self.get_rules(credentials, input_data.subreddit) + all_rules = [] + + for idx, rule in enumerate(raw_rules): + rule_data = SubredditRule( + short_name=rule.short_name, + description=rule.description or "", + kind=rule.kind, + violation_reason=rule.violation_reason or rule.short_name, + priority=idx, + ) + all_rules.append(rule_data) + yield "rule", rule_data + + yield "rules", all_rules + yield "subreddit", input_data.subreddit + except Exception as e: + yield "error", str(e) diff --git a/autogpt_platform/backend/backend/blocks/smart_decision_maker.py b/autogpt_platform/backend/backend/blocks/smart_decision_maker.py index 751f6af37f..ff6042eaab 100644 --- a/autogpt_platform/backend/backend/blocks/smart_decision_maker.py +++ b/autogpt_platform/backend/backend/blocks/smart_decision_maker.py @@ -391,8 +391,12 @@ class SmartDecisionMakerBlock(Block): """ block = sink_node.block + # Use custom name from node metadata if set, otherwise fall back to block.name + custom_name = sink_node.metadata.get("customized_name") + tool_name = custom_name if custom_name else block.name + tool_function: dict[str, Any] = { - "name": SmartDecisionMakerBlock.cleanup(block.name), + "name": SmartDecisionMakerBlock.cleanup(tool_name), "description": block.description, } sink_block_input_schema = block.input_schema @@ -489,14 +493,24 @@ class SmartDecisionMakerBlock(Block): f"Sink graph metadata not found: {graph_id} {graph_version}" ) + # Use custom name from node metadata if set, otherwise fall back to graph name + custom_name = sink_node.metadata.get("customized_name") + tool_name = custom_name if custom_name else sink_graph_meta.name + tool_function: dict[str, Any] = { - "name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name), + "name": SmartDecisionMakerBlock.cleanup(tool_name), "description": sink_graph_meta.description, } properties = {} + field_mapping = {} for link in links: + field_name = link.sink_name + + clean_field_name = SmartDecisionMakerBlock.cleanup(field_name) + field_mapping[clean_field_name] = field_name + sink_block_input_schema = sink_node.input_default["input_schema"] sink_block_properties = sink_block_input_schema.get("properties", {}).get( link.sink_name, {} @@ -506,7 +520,7 @@ class SmartDecisionMakerBlock(Block): if "description" in sink_block_properties else f"The {link.sink_name} of the tool" ) - properties[link.sink_name] = { + properties[clean_field_name] = { "type": "string", "description": description, "default": json.dumps(sink_block_properties.get("default", None)), @@ -519,7 +533,7 @@ class SmartDecisionMakerBlock(Block): "strict": True, } - # Store node info for later use in output processing + tool_function["_field_mapping"] = field_mapping tool_function["_sink_node_id"] = sink_node.id return {"type": "function", "function": tool_function} @@ -975,10 +989,28 @@ class SmartDecisionMakerBlock(Block): graph_version: int, execution_context: ExecutionContext, execution_processor: "ExecutionProcessor", + nodes_to_skip: set[str] | None = None, **kwargs, ) -> BlockOutput: tool_functions = await self._create_tool_node_signatures(node_id) + original_tool_count = len(tool_functions) + + # Filter out tools for nodes that should be skipped (e.g., missing optional credentials) + if nodes_to_skip: + tool_functions = [ + tf + for tf in tool_functions + if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip + ] + + # Only raise error if we had tools but they were all filtered out + if original_tool_count > 0 and not tool_functions: + raise ValueError( + "No available tools to execute - all downstream nodes are unavailable " + "(possibly due to missing optional credentials)" + ) + yield "tool_functions", json.dumps(tool_functions) conversation_history = input_data.conversation_history or [] @@ -1129,8 +1161,9 @@ class SmartDecisionMakerBlock(Block): original_field_name = field_mapping.get(clean_arg_name, clean_arg_name) arg_value = tool_args.get(clean_arg_name) - sanitized_arg_name = self.cleanup(original_field_name) - emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}" + # Use original_field_name directly (not sanitized) to match link sink_name + # The field_mapping already translates from LLM's cleaned names to original names + emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}" logger.debug( "[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s", diff --git a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py index c930fab37e..8266d433ad 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py +++ b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py @@ -1057,3 +1057,153 @@ async def test_smart_decision_maker_traditional_mode_default(): ) # Should yield individual tool parameters assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs assert "conversations" in outputs + + +@pytest.mark.asyncio +async def test_smart_decision_maker_uses_customized_name_for_blocks(): + """Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names.""" + from unittest.mock import MagicMock + + from backend.blocks.basic import StoreValueBlock + from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock + from backend.data.graph import Link, Node + + # Create a mock node with customized_name in metadata + mock_node = MagicMock(spec=Node) + mock_node.id = "test-node-id" + mock_node.block_id = StoreValueBlock().id + mock_node.metadata = {"customized_name": "My Custom Tool Name"} + mock_node.block = StoreValueBlock() + + # Create a mock link + mock_link = MagicMock(spec=Link) + mock_link.sink_name = "input" + + # Call the function directly + result = await SmartDecisionMakerBlock._create_block_function_signature( + mock_node, [mock_link] + ) + + # Verify the tool name uses the customized name (cleaned up) + assert result["type"] == "function" + assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version + assert result["function"]["_sink_node_id"] == "test-node-id" + + +@pytest.mark.asyncio +async def test_smart_decision_maker_falls_back_to_block_name(): + """Test that SmartDecisionMakerBlock falls back to block.name when no customized_name.""" + from unittest.mock import MagicMock + + from backend.blocks.basic import StoreValueBlock + from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock + from backend.data.graph import Link, Node + + # Create a mock node without customized_name + mock_node = MagicMock(spec=Node) + mock_node.id = "test-node-id" + mock_node.block_id = StoreValueBlock().id + mock_node.metadata = {} # No customized_name + mock_node.block = StoreValueBlock() + + # Create a mock link + mock_link = MagicMock(spec=Link) + mock_link.sink_name = "input" + + # Call the function directly + result = await SmartDecisionMakerBlock._create_block_function_signature( + mock_node, [mock_link] + ) + + # Verify the tool name uses the block's default name + assert result["type"] == "function" + assert result["function"]["name"] == "storevalueblock" # Default block name cleaned + assert result["function"]["_sink_node_id"] == "test-node-id" + + +@pytest.mark.asyncio +async def test_smart_decision_maker_uses_customized_name_for_agents(): + """Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock + from backend.data.graph import Link, Node + + # Create a mock node with customized_name in metadata + mock_node = MagicMock(spec=Node) + mock_node.id = "test-agent-node-id" + mock_node.metadata = {"customized_name": "My Custom Agent"} + mock_node.input_default = { + "graph_id": "test-graph-id", + "graph_version": 1, + "input_schema": {"properties": {"test_input": {"description": "Test input"}}}, + } + + # Create a mock link + mock_link = MagicMock(spec=Link) + mock_link.sink_name = "test_input" + + # Mock the database client + mock_graph_meta = MagicMock() + mock_graph_meta.name = "Original Agent Name" + mock_graph_meta.description = "Agent description" + + mock_db_client = AsyncMock() + mock_db_client.get_graph_metadata.return_value = mock_graph_meta + + with patch( + "backend.blocks.smart_decision_maker.get_database_manager_async_client", + return_value=mock_db_client, + ): + result = await SmartDecisionMakerBlock._create_agent_function_signature( + mock_node, [mock_link] + ) + + # Verify the tool name uses the customized name (cleaned up) + assert result["type"] == "function" + assert result["function"]["name"] == "my_custom_agent" # Cleaned version + assert result["function"]["_sink_node_id"] == "test-agent-node-id" + + +@pytest.mark.asyncio +async def test_smart_decision_maker_agent_falls_back_to_graph_name(): + """Test that agent node falls back to graph name when no customized_name.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock + from backend.data.graph import Link, Node + + # Create a mock node without customized_name + mock_node = MagicMock(spec=Node) + mock_node.id = "test-agent-node-id" + mock_node.metadata = {} # No customized_name + mock_node.input_default = { + "graph_id": "test-graph-id", + "graph_version": 1, + "input_schema": {"properties": {"test_input": {"description": "Test input"}}}, + } + + # Create a mock link + mock_link = MagicMock(spec=Link) + mock_link.sink_name = "test_input" + + # Mock the database client + mock_graph_meta = MagicMock() + mock_graph_meta.name = "Original Agent Name" + mock_graph_meta.description = "Agent description" + + mock_db_client = AsyncMock() + mock_db_client.get_graph_metadata.return_value = mock_graph_meta + + with patch( + "backend.blocks.smart_decision_maker.get_database_manager_async_client", + return_value=mock_db_client, + ): + result = await SmartDecisionMakerBlock._create_agent_function_signature( + mock_node, [mock_link] + ) + + # Verify the tool name uses the graph's default name + assert result["type"] == "function" + assert result["function"]["name"] == "original_agent_name" # Graph name cleaned + assert result["function"]["_sink_node_id"] == "test-agent-node-id" diff --git a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dict.py b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dict.py index 839bdc5e15..2087c0b7d6 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dict.py +++ b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dict.py @@ -15,6 +15,7 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields(): mock_node.block = CreateDictionaryBlock() mock_node.block_id = CreateDictionaryBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic dictionary fields mock_links = [ @@ -77,6 +78,7 @@ async def test_smart_decision_maker_handles_dynamic_list_fields(): mock_node.block = AddToListBlock() mock_node.block_id = AddToListBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic list fields mock_links = [ diff --git a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dynamic_fields.py b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dynamic_fields.py index 6ed830e517..af89a83f86 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dynamic_fields.py +++ b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dynamic_fields.py @@ -44,6 +44,7 @@ async def test_create_block_function_signature_with_dict_fields(): mock_node.block = CreateDictionaryBlock() mock_node.block_id = CreateDictionaryBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic dictionary fields (source sanitized, sink original) mock_links = [ @@ -106,6 +107,7 @@ async def test_create_block_function_signature_with_list_fields(): mock_node.block = AddToListBlock() mock_node.block_id = AddToListBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic list fields mock_links = [ @@ -159,6 +161,7 @@ async def test_create_block_function_signature_with_object_fields(): mock_node.block = MatchTextPatternBlock() mock_node.block_id = MatchTextPatternBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic object fields mock_links = [ @@ -208,11 +211,13 @@ async def test_create_tool_node_signatures(): mock_dict_node.block = CreateDictionaryBlock() mock_dict_node.block_id = CreateDictionaryBlock().id mock_dict_node.input_default = {} + mock_dict_node.metadata = {} mock_list_node = Mock() mock_list_node.block = AddToListBlock() mock_list_node.block_id = AddToListBlock().id mock_list_node.input_default = {} + mock_list_node.metadata = {} # Mock links with dynamic fields dict_link1 = Mock( @@ -423,6 +428,7 @@ async def test_mixed_regular_and_dynamic_fields(): mock_node.block.name = "TestBlock" mock_node.block.description = "A test block" mock_node.block.input_schema = Mock() + mock_node.metadata = {} # Mock the get_field_schema to return a proper schema for regular fields def get_field_schema(field_name): diff --git a/autogpt_platform/backend/backend/blocks/wordpress/__init__.py b/autogpt_platform/backend/backend/blocks/wordpress/__init__.py index c7b1e26eea..3eae4a1063 100644 --- a/autogpt_platform/backend/backend/blocks/wordpress/__init__.py +++ b/autogpt_platform/backend/backend/blocks/wordpress/__init__.py @@ -1,3 +1,3 @@ -from .blog import WordPressCreatePostBlock +from .blog import WordPressCreatePostBlock, WordPressGetAllPostsBlock -__all__ = ["WordPressCreatePostBlock"] +__all__ = ["WordPressCreatePostBlock", "WordPressGetAllPostsBlock"] diff --git a/autogpt_platform/backend/backend/blocks/wordpress/_api.py b/autogpt_platform/backend/backend/blocks/wordpress/_api.py index 78f535947b..d21dc3e05d 100644 --- a/autogpt_platform/backend/backend/blocks/wordpress/_api.py +++ b/autogpt_platform/backend/backend/blocks/wordpress/_api.py @@ -161,7 +161,7 @@ async def oauth_exchange_code_for_tokens( grant_type="authorization_code", ).model_dump(exclude_none=True) - response = await Requests().post( + response = await Requests(raise_for_status=False).post( f"{WORDPRESS_BASE_URL}oauth2/token", headers=headers, data=data, @@ -205,7 +205,7 @@ async def oauth_refresh_tokens( grant_type="refresh_token", ).model_dump(exclude_none=True) - response = await Requests().post( + response = await Requests(raise_for_status=False).post( f"{WORDPRESS_BASE_URL}oauth2/token", headers=headers, data=data, @@ -252,7 +252,7 @@ async def validate_token( "token": token, } - response = await Requests().get( + response = await Requests(raise_for_status=False).get( f"{WORDPRESS_BASE_URL}oauth2/token-info", params=params, ) @@ -296,7 +296,7 @@ async def make_api_request( url = f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}" - request_method = getattr(Requests(), method.lower()) + request_method = getattr(Requests(raise_for_status=False), method.lower()) response = await request_method( url, headers=headers, @@ -476,6 +476,7 @@ async def create_post( data["tags"] = ",".join(str(t) for t in data["tags"]) # Make the API request + site = normalize_site(site) endpoint = f"/rest/v1.1/sites/{site}/posts/new" headers = { @@ -483,7 +484,7 @@ async def create_post( "Content-Type": "application/x-www-form-urlencoded", } - response = await Requests().post( + response = await Requests(raise_for_status=False).post( f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}", headers=headers, data=data, @@ -499,3 +500,132 @@ async def create_post( ) error_message = error_data.get("message", response.text) raise ValueError(f"Failed to create post: {response.status} - {error_message}") + + +class Post(BaseModel): + """Response model for individual posts in a posts list response. + + This is a simplified version compared to PostResponse, as the list endpoint + returns less detailed information than the create/get single post endpoints. + """ + + ID: int + site_ID: int + author: PostAuthor + date: datetime + modified: datetime + title: str + URL: str + short_URL: str + content: str | None = None + excerpt: str | None = None + slug: str + guid: str + status: str + sticky: bool + password: str | None = "" + parent: Union[Dict[str, Any], bool, None] = None + type: str + discussion: Dict[str, Union[str, bool, int]] | None = None + likes_enabled: bool | None = None + sharing_enabled: bool | None = None + like_count: int | None = None + i_like: bool | None = None + is_reblogged: bool | None = None + is_following: bool | None = None + global_ID: str | None = None + featured_image: str | None = None + post_thumbnail: Dict[str, Any] | None = None + format: str | None = None + geo: Union[Dict[str, Any], bool, None] = None + menu_order: int | None = None + page_template: str | None = None + publicize_URLs: List[str] | None = None + terms: Dict[str, Dict[str, Any]] | None = None + tags: Dict[str, Dict[str, Any]] | None = None + categories: Dict[str, Dict[str, Any]] | None = None + attachments: Dict[str, Dict[str, Any]] | None = None + attachment_count: int | None = None + metadata: List[Dict[str, Any]] | None = None + meta: Dict[str, Any] | None = None + capabilities: Dict[str, bool] | None = None + revisions: List[int] | None = None + other_URLs: Dict[str, Any] | None = None + + +class PostsResponse(BaseModel): + """Response model for WordPress posts list.""" + + found: int + posts: List[Post] + meta: Dict[str, Any] + + +def normalize_site(site: str) -> str: + """ + Normalize a site identifier by stripping protocol and trailing slashes. + + Args: + site: Site URL, domain, or ID (e.g., "https://myblog.wordpress.com/", "myblog.wordpress.com", "123456789") + + Returns: + Normalized site identifier (domain or ID only) + """ + site = site.strip() + if site.startswith("https://"): + site = site[8:] + elif site.startswith("http://"): + site = site[7:] + return site.rstrip("/") + + +async def get_posts( + credentials: Credentials, + site: str, + status: PostStatus | None = None, + number: int = 100, + offset: int = 0, +) -> PostsResponse: + """ + Get posts from a WordPress site. + + Args: + credentials: OAuth credentials + site: Site ID or domain (e.g., "myblog.wordpress.com" or "123456789") + status: Filter by post status using PostStatus enum, or None for all + number: Number of posts to retrieve (max 100) + offset: Number of posts to skip (for pagination) + + Returns: + PostsResponse with the list of posts + """ + site = normalize_site(site) + endpoint = f"/rest/v1.1/sites/{site}/posts" + + headers = { + "Authorization": credentials.auth_header(), + } + + params: Dict[str, Any] = { + "number": max(1, min(number, 100)), # 1–100 posts per request + "offset": offset, + } + + if status: + params["status"] = status.value + response = await Requests(raise_for_status=False).get( + f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}", + headers=headers, + params=params, + ) + + if response.ok: + return PostsResponse.model_validate(response.json()) + + error_data = ( + response.json() + if response.headers.get("content-type", "").startswith("application/json") + else {} + ) + error_message = error_data.get("message", response.text) + raise ValueError(f"Failed to get posts: {response.status} - {error_message}") diff --git a/autogpt_platform/backend/backend/blocks/wordpress/blog.py b/autogpt_platform/backend/backend/blocks/wordpress/blog.py index c0ad5eca54..22b691480b 100644 --- a/autogpt_platform/backend/backend/blocks/wordpress/blog.py +++ b/autogpt_platform/backend/backend/blocks/wordpress/blog.py @@ -9,7 +9,15 @@ from backend.sdk import ( SchemaField, ) -from ._api import CreatePostRequest, PostResponse, PostStatus, create_post +from ._api import ( + CreatePostRequest, + Post, + PostResponse, + PostsResponse, + PostStatus, + create_post, + get_posts, +) from ._config import wordpress @@ -49,8 +57,15 @@ class WordPressCreatePostBlock(Block): media_urls: list[str] = SchemaField( description="URLs of images to sideload and attach to the post", default=[] ) + publish_as_draft: bool = SchemaField( + description="If True, publishes the post as a draft. If False, publishes it publicly.", + default=False, + ) class Output(BlockSchemaOutput): + site: str = SchemaField( + description="The site ID or domain (pass-through for chaining with other blocks)" + ) post_id: int = SchemaField(description="The ID of the created post") post_url: str = SchemaField(description="The full URL of the created post") short_url: str = SchemaField(description="The shortened wp.me URL") @@ -78,7 +93,9 @@ class WordPressCreatePostBlock(Block): tags=input_data.tags, featured_image=input_data.featured_image, media_urls=input_data.media_urls, - status=PostStatus.PUBLISH, + status=( + PostStatus.DRAFT if input_data.publish_as_draft else PostStatus.PUBLISH + ), ) post_response: PostResponse = await create_post( @@ -87,7 +104,69 @@ class WordPressCreatePostBlock(Block): post_data=post_request, ) + yield "site", input_data.site yield "post_id", post_response.ID yield "post_url", post_response.URL yield "short_url", post_response.short_URL yield "post_data", post_response.model_dump() + + +class WordPressGetAllPostsBlock(Block): + """ + Fetches all posts from a WordPress.com site or Jetpack-enabled site. + Supports filtering by status and pagination. + """ + + class Input(BlockSchemaInput): + credentials: CredentialsMetaInput = wordpress.credentials_field() + site: str = SchemaField( + description="Site ID or domain (e.g., 'myblog.wordpress.com' or '123456789')" + ) + status: PostStatus | None = SchemaField( + description="Filter by post status, or None for all", + default=None, + ) + number: int = SchemaField( + description="Number of posts to retrieve (max 100 per request)", default=20 + ) + offset: int = SchemaField( + description="Number of posts to skip (for pagination)", default=0 + ) + + class Output(BlockSchemaOutput): + site: str = SchemaField( + description="The site ID or domain (pass-through for chaining with other blocks)" + ) + found: int = SchemaField(description="Total number of posts found") + posts: list[Post] = SchemaField( + description="List of post objects with their details" + ) + post: Post = SchemaField( + description="Individual post object (yielded for each post)" + ) + + def __init__(self): + super().__init__( + id="97728fa7-7f6f-4789-ba0c-f2c114119536", + description="Fetch all posts from WordPress.com or Jetpack sites", + categories={BlockCategory.SOCIAL}, + input_schema=self.Input, + output_schema=self.Output, + ) + + async def run( + self, input_data: Input, *, credentials: Credentials, **kwargs + ) -> BlockOutput: + posts_response: PostsResponse = await get_posts( + credentials=credentials, + site=input_data.site, + status=input_data.status, + number=input_data.number, + offset=input_data.offset, + ) + + yield "site", input_data.site + yield "found", posts_response.found + yield "posts", posts_response.posts + for post in posts_response.posts: + yield "post", post diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 727688dcf0..24a68cca03 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -50,6 +50,8 @@ from .model import ( logger = logging.getLogger(__name__) if TYPE_CHECKING: + from backend.data.execution import ExecutionContext + from .graph import Link app_config = Config() @@ -472,6 +474,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): self.block_type = block_type self.webhook_config = webhook_config self.execution_stats: NodeExecutionStats = NodeExecutionStats() + self.requires_human_review: bool = False if self.webhook_config: if isinstance(self.webhook_config, BlockWebhookConfig): @@ -614,7 +617,77 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): block_id=self.id, ) from ex + async def is_block_exec_need_review( + self, + input_data: BlockInput, + *, + user_id: str, + node_exec_id: str, + graph_exec_id: str, + graph_id: str, + graph_version: int, + execution_context: "ExecutionContext", + **kwargs, + ) -> tuple[bool, BlockInput]: + """ + Check if this block execution needs human review and handle the review process. + + Returns: + Tuple of (should_pause, input_data_to_use) + - should_pause: True if execution should be paused for review + - input_data_to_use: The input data to use (may be modified by reviewer) + """ + # Skip review if not required or safe mode is disabled + if not self.requires_human_review or not execution_context.safe_mode: + return False, input_data + + from backend.blocks.helpers.review import HITLReviewHelper + + # Handle the review request and get decision + decision = await HITLReviewHelper.handle_review_decision( + input_data=input_data, + user_id=user_id, + node_exec_id=node_exec_id, + graph_exec_id=graph_exec_id, + graph_id=graph_id, + graph_version=graph_version, + execution_context=execution_context, + block_name=self.name, + editable=True, + ) + + if decision is None: + # We're awaiting review - pause execution + return True, input_data + + if not decision.should_proceed: + # Review was rejected, raise an error to stop execution + raise BlockExecutionError( + message=f"Block execution rejected by reviewer: {decision.message}", + block_name=self.name, + block_id=self.id, + ) + + # Review was approved - use the potentially modified data + # ReviewResult.data must be a dict for block inputs + reviewed_data = decision.review_result.data + if not isinstance(reviewed_data, dict): + raise BlockExecutionError( + message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}", + block_name=self.name, + block_id=self.id, + ) + return False, reviewed_data + async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput: + # Check for review requirement and get potentially modified input data + should_pause, input_data = await self.is_block_exec_need_review( + input_data, **kwargs + ) + if should_pause: + return + + # Validate the input data (original or reviewer-modified) once if error := self.input_schema.validate_data(input_data): raise BlockInputError( message=f"Unable to execute block with invalid input data: {error}", @@ -622,6 +695,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): block_id=self.id, ) + # Use the validated input data async for output_name, output_data in self.run( self.input_schema(**{k: v for k, v in input_data.items() if v is not None}), **kwargs, diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index 020a5a1906..2759dfe179 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -383,6 +383,7 @@ class GraphExecutionWithNodes(GraphExecution): self, execution_context: ExecutionContext, compiled_nodes_input_masks: Optional[NodesInputMasks] = None, + nodes_to_skip: Optional[set[str]] = None, ): return GraphExecutionEntry( user_id=self.user_id, @@ -390,6 +391,7 @@ class GraphExecutionWithNodes(GraphExecution): graph_version=self.graph_version or 0, graph_exec_id=self.id, nodes_input_masks=compiled_nodes_input_masks, + nodes_to_skip=nodes_to_skip or set(), execution_context=execution_context, ) @@ -1145,6 +1147,8 @@ class GraphExecutionEntry(BaseModel): graph_id: str graph_version: int nodes_input_masks: Optional[NodesInputMasks] = None + nodes_to_skip: set[str] = Field(default_factory=set) + """Node IDs that should be skipped due to optional credentials not being configured.""" execution_context: ExecutionContext = Field(default_factory=ExecutionContext) diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 0757a86f4a..e9be80892c 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -94,6 +94,15 @@ class Node(BaseDbModel): input_links: list[Link] = [] output_links: list[Link] = [] + @property + def credentials_optional(self) -> bool: + """ + Whether credentials are optional for this node. + When True and credentials are not configured, the node will be skipped + during execution rather than causing a validation error. + """ + return self.metadata.get("credentials_optional", False) + @property def block(self) -> AnyBlockSchema | "_UnknownBlockBase": """Get the block for this node. Returns UnknownBlock if block is deleted/missing.""" @@ -235,7 +244,10 @@ class BaseGraph(BaseDbModel): return any( node.block_id for node in self.nodes - if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP + if ( + node.block.block_type == BlockType.HUMAN_IN_THE_LOOP + or node.block.requires_human_review + ) ) @property @@ -326,7 +338,35 @@ class Graph(BaseGraph): @computed_field @property def credentials_input_schema(self) -> dict[str, Any]: - return self._credentials_input_schema.jsonschema() + schema = self._credentials_input_schema.jsonschema() + + # Determine which credential fields are required based on credentials_optional metadata + graph_credentials_inputs = self.aggregate_credentials_inputs() + required_fields = [] + + # Build a map of node_id -> node for quick lookup + all_nodes = {node.id: node for node in self.nodes} + for sub_graph in self.sub_graphs: + for node in sub_graph.nodes: + all_nodes[node.id] = node + + for field_key, ( + _field_info, + node_field_pairs, + ) in graph_credentials_inputs.items(): + # A field is required if ANY node using it has credentials_optional=False + is_required = False + for node_id, _field_name in node_field_pairs: + node = all_nodes.get(node_id) + if node and not node.credentials_optional: + is_required = True + break + + if is_required: + required_fields.append(field_key) + + schema["required"] = required_fields + return schema @property def _credentials_input_schema(self) -> type[BlockSchema]: diff --git a/autogpt_platform/backend/backend/data/graph_test.py b/autogpt_platform/backend/backend/data/graph_test.py index 044d75e0ca..eea7277eb9 100644 --- a/autogpt_platform/backend/backend/data/graph_test.py +++ b/autogpt_platform/backend/backend/data/graph_test.py @@ -396,3 +396,58 @@ async def test_access_store_listing_graph(server: SpinTestServer): created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b" ) assert got_graph is not None + + +# ============================================================================ +# Tests for Optional Credentials Feature +# ============================================================================ + + +def test_node_credentials_optional_default(): + """Test that credentials_optional defaults to False when not set in metadata.""" + node = Node( + id="test_node", + block_id=StoreValueBlock().id, + input_default={}, + metadata={}, + ) + assert node.credentials_optional is False + + +def test_node_credentials_optional_true(): + """Test that credentials_optional returns True when explicitly set.""" + node = Node( + id="test_node", + block_id=StoreValueBlock().id, + input_default={}, + metadata={"credentials_optional": True}, + ) + assert node.credentials_optional is True + + +def test_node_credentials_optional_false(): + """Test that credentials_optional returns False when explicitly set to False.""" + node = Node( + id="test_node", + block_id=StoreValueBlock().id, + input_default={}, + metadata={"credentials_optional": False}, + ) + assert node.credentials_optional is False + + +def test_node_credentials_optional_with_other_metadata(): + """Test that credentials_optional works correctly with other metadata present.""" + node = Node( + id="test_node", + block_id=StoreValueBlock().id, + input_default={}, + metadata={ + "position": {"x": 100, "y": 200}, + "customized_name": "My Custom Node", + "credentials_optional": True, + }, + ) + assert node.credentials_optional is True + assert node.metadata["position"] == {"x": 100, "y": 200} + assert node.metadata["customized_name"] == "My Custom Node" diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 75459c5a2a..39d4f984eb 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -178,6 +178,7 @@ async def execute_node( execution_processor: "ExecutionProcessor", execution_stats: NodeExecutionStats | None = None, nodes_input_masks: Optional[NodesInputMasks] = None, + nodes_to_skip: Optional[set[str]] = None, ) -> BlockOutput: """ Execute a node in the graph. This will trigger a block execution on a node, @@ -245,6 +246,7 @@ async def execute_node( "user_id": user_id, "execution_context": execution_context, "execution_processor": execution_processor, + "nodes_to_skip": nodes_to_skip or set(), } # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent @@ -542,6 +544,7 @@ class ExecutionProcessor: node_exec_progress: NodeExecutionProgress, nodes_input_masks: Optional[NodesInputMasks], graph_stats_pair: tuple[GraphExecutionStats, threading.Lock], + nodes_to_skip: Optional[set[str]] = None, ) -> NodeExecutionStats: log_metadata = LogMetadata( logger=_logger, @@ -564,6 +567,7 @@ class ExecutionProcessor: db_client=db_client, log_metadata=log_metadata, nodes_input_masks=nodes_input_masks, + nodes_to_skip=nodes_to_skip, ) if isinstance(status, BaseException): raise status @@ -609,6 +613,7 @@ class ExecutionProcessor: db_client: "DatabaseManagerAsyncClient", log_metadata: LogMetadata, nodes_input_masks: Optional[NodesInputMasks] = None, + nodes_to_skip: Optional[set[str]] = None, ) -> ExecutionStatus: status = ExecutionStatus.RUNNING @@ -645,6 +650,7 @@ class ExecutionProcessor: execution_processor=self, execution_stats=stats, nodes_input_masks=nodes_input_masks, + nodes_to_skip=nodes_to_skip, ): await persist_output(output_name, output_data) @@ -956,6 +962,21 @@ class ExecutionProcessor: queued_node_exec = execution_queue.get() + # Check if this node should be skipped due to optional credentials + if queued_node_exec.node_id in graph_exec.nodes_to_skip: + log_metadata.info( + f"Skipping node execution {queued_node_exec.node_exec_id} " + f"for node {queued_node_exec.node_id} - optional credentials not configured" + ) + # Mark the node as completed without executing + # No outputs will be produced, so downstream nodes won't trigger + update_node_execution_status( + db_client=db_client, + exec_id=queued_node_exec.node_exec_id, + status=ExecutionStatus.COMPLETED, + ) + continue + log_metadata.debug( f"Dispatching node execution {queued_node_exec.node_exec_id} " f"for node {queued_node_exec.node_id}", @@ -1016,6 +1037,7 @@ class ExecutionProcessor: execution_stats, execution_stats_lock, ), + nodes_to_skip=graph_exec.nodes_to_skip, ), self.node_execution_loop, ) diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index bcd3dcf3b6..1fb2b9404f 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -239,14 +239,19 @@ async def _validate_node_input_credentials( graph: GraphModel, user_id: str, nodes_input_masks: Optional[NodesInputMasks] = None, -) -> dict[str, dict[str, str]]: +) -> tuple[dict[str, dict[str, str]], set[str]]: """ - Checks all credentials for all nodes of the graph and returns structured errors. + Checks all credentials for all nodes of the graph and returns structured errors + and a set of nodes that should be skipped due to optional missing credentials. Returns: - dict[node_id, dict[field_name, error_message]]: Credential validation errors per node + tuple[ + dict[node_id, dict[field_name, error_message]]: Credential validation errors per node, + set[node_id]: Nodes that should be skipped (optional credentials not configured) + ] """ credential_errors: dict[str, dict[str, str]] = defaultdict(dict) + nodes_to_skip: set[str] = set() for node in graph.nodes: block = node.block @@ -256,27 +261,46 @@ async def _validate_node_input_credentials( if not credentials_fields: continue + # Track if any credential field is missing for this node + has_missing_credentials = False + for field_name, credentials_meta_type in credentials_fields.items(): try: + # Check nodes_input_masks first, then input_default + field_value = None if ( nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)) and field_name in node_input_mask ): - credentials_meta = credentials_meta_type.model_validate( - node_input_mask[field_name] - ) + field_value = node_input_mask[field_name] elif field_name in node.input_default: - credentials_meta = credentials_meta_type.model_validate( - node.input_default[field_name] - ) - else: - # Missing credentials - credential_errors[node.id][ - field_name - ] = "These credentials are required" - continue + # For optional credentials, don't use input_default - treat as missing + # This prevents stale credential IDs from failing validation + if node.credentials_optional: + field_value = None + else: + field_value = node.input_default[field_name] + + # Check if credentials are missing (None, empty, or not present) + if field_value is None or ( + isinstance(field_value, dict) and not field_value.get("id") + ): + has_missing_credentials = True + # If node has credentials_optional flag, mark for skipping instead of error + if node.credentials_optional: + continue # Don't add error, will be marked for skip after loop + else: + credential_errors[node.id][ + field_name + ] = "These credentials are required" + continue + + credentials_meta = credentials_meta_type.model_validate(field_value) + except ValidationError as e: + # Validation error means credentials were provided but invalid + # This should always be an error, even if optional credential_errors[node.id][field_name] = f"Invalid credentials: {e}" continue @@ -287,6 +311,7 @@ async def _validate_node_input_credentials( ) except Exception as e: # Handle any errors fetching credentials + # If credentials were explicitly configured but unavailable, it's an error credential_errors[node.id][ field_name ] = f"Credentials not available: {e}" @@ -313,7 +338,19 @@ async def _validate_node_input_credentials( ] = "Invalid credentials: type/provider mismatch" continue - return credential_errors + # If node has optional credentials and any are missing, mark for skipping + # But only if there are no other errors for this node + if ( + has_missing_credentials + and node.credentials_optional + and node.id not in credential_errors + ): + nodes_to_skip.add(node.id) + logger.info( + f"Node #{node.id} will be skipped: optional credentials not configured" + ) + + return credential_errors, nodes_to_skip def make_node_credentials_input_map( @@ -355,21 +392,25 @@ async def validate_graph_with_credentials( graph: GraphModel, user_id: str, nodes_input_masks: Optional[NodesInputMasks] = None, -) -> Mapping[str, Mapping[str, str]]: +) -> tuple[Mapping[str, Mapping[str, str]], set[str]]: """ - Validate graph including credentials and return structured errors per node. + Validate graph including credentials and return structured errors per node, + along with a set of nodes that should be skipped due to optional missing credentials. Returns: - dict[node_id, dict[field_name, error_message]]: Validation errors per node + tuple[ + dict[node_id, dict[field_name, error_message]]: Validation errors per node, + set[node_id]: Nodes that should be skipped (optional credentials not configured) + ] """ # Get input validation errors node_input_errors = GraphModel.validate_graph_get_errors( graph, for_run=True, nodes_input_masks=nodes_input_masks ) - # Get credential input/availability/validation errors - node_credential_input_errors = await _validate_node_input_credentials( - graph, user_id, nodes_input_masks + # Get credential input/availability/validation errors and nodes to skip + node_credential_input_errors, nodes_to_skip = ( + await _validate_node_input_credentials(graph, user_id, nodes_input_masks) ) # Merge credential errors with structural errors @@ -378,7 +419,7 @@ async def validate_graph_with_credentials( node_input_errors[node_id] = {} node_input_errors[node_id].update(field_errors) - return node_input_errors + return node_input_errors, nodes_to_skip async def _construct_starting_node_execution_input( @@ -386,7 +427,7 @@ async def _construct_starting_node_execution_input( user_id: str, graph_inputs: BlockInput, nodes_input_masks: Optional[NodesInputMasks] = None, -) -> list[tuple[str, BlockInput]]: +) -> tuple[list[tuple[str, BlockInput]], set[str]]: """ Validates and prepares the input data for executing a graph. This function checks the graph for starting nodes, validates the input data @@ -400,11 +441,14 @@ async def _construct_starting_node_execution_input( node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]` Returns: - list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and - the corresponding input data for that node. + tuple[ + list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID + and the corresponding input data for that node. + set[str]: Node IDs that should be skipped (optional credentials not configured) + ] """ # Use new validation function that includes credentials - validation_errors = await validate_graph_with_credentials( + validation_errors, nodes_to_skip = await validate_graph_with_credentials( graph, user_id, nodes_input_masks ) n_error_nodes = len(validation_errors) @@ -445,7 +489,7 @@ async def _construct_starting_node_execution_input( "No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes." ) - return nodes_input + return nodes_input, nodes_to_skip async def validate_and_construct_node_execution_input( @@ -456,7 +500,7 @@ async def validate_and_construct_node_execution_input( graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None, nodes_input_masks: Optional[NodesInputMasks] = None, is_sub_graph: bool = False, -) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]: +) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]: """ Public wrapper that handles graph fetching, credential mapping, and validation+construction. This centralizes the logic used by both scheduler validation and actual execution. @@ -473,6 +517,7 @@ async def validate_and_construct_node_execution_input( GraphModel: Full graph object for the given `graph_id`. list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs. dict[str, BlockInput]: Node input masks including all passed-in credentials. + set[str]: Node IDs that should be skipped (optional credentials not configured). Raises: NotFoundError: If the graph is not found. @@ -514,14 +559,16 @@ async def validate_and_construct_node_execution_input( nodes_input_masks or {}, ) - starting_nodes_input = await _construct_starting_node_execution_input( - graph=graph, - user_id=user_id, - graph_inputs=graph_inputs, - nodes_input_masks=nodes_input_masks, + starting_nodes_input, nodes_to_skip = ( + await _construct_starting_node_execution_input( + graph=graph, + user_id=user_id, + graph_inputs=graph_inputs, + nodes_input_masks=nodes_input_masks, + ) ) - return graph, starting_nodes_input, nodes_input_masks + return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip def _merge_nodes_input_masks( @@ -779,6 +826,9 @@ async def add_graph_execution( # Use existing execution's compiled input masks compiled_nodes_input_masks = graph_exec.nodes_input_masks or {} + # For resumed executions, nodes_to_skip was already determined at creation time + # TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes + nodes_to_skip: set[str] = set() logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}") else: @@ -787,7 +837,7 @@ async def add_graph_execution( ) # Create new execution - graph, starting_nodes_input, compiled_nodes_input_masks = ( + graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = ( await validate_and_construct_node_execution_input( graph_id=graph_id, user_id=user_id, @@ -836,6 +886,7 @@ async def add_graph_execution( try: graph_exec_entry = graph_exec.to_graph_execution_entry( compiled_nodes_input_masks=compiled_nodes_input_masks, + nodes_to_skip=nodes_to_skip, execution_context=execution_context, ) logger.info(f"Publishing execution {graph_exec.id} to execution queue") diff --git a/autogpt_platform/backend/backend/executor/utils_test.py b/autogpt_platform/backend/backend/executor/utils_test.py index 8854214e14..0e652f9627 100644 --- a/autogpt_platform/backend/backend/executor/utils_test.py +++ b/autogpt_platform/backend/backend/executor/utils_test.py @@ -367,10 +367,13 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture): ) # Setup mock returns + # The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip) + nodes_to_skip: set[str] = set() mock_validate.return_value = ( mock_graph, starting_nodes_input, compiled_nodes_input_masks, + nodes_to_skip, ) mock_prisma.is_connected.return_value = True mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec) @@ -456,3 +459,212 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture): # Both executions should succeed (though they create different objects) assert result1 == mock_graph_exec assert result2 == mock_graph_exec_2 + + +# ============================================================================ +# Tests for Optional Credentials Feature +# ============================================================================ + + +@pytest.mark.asyncio +async def test_validate_node_input_credentials_returns_nodes_to_skip( + mocker: MockerFixture, +): + """ + Test that _validate_node_input_credentials returns nodes_to_skip set + for nodes with credentials_optional=True and missing credentials. + """ + from backend.executor.utils import _validate_node_input_credentials + + # Create a mock node with credentials_optional=True + mock_node = mocker.MagicMock() + mock_node.id = "node-with-optional-creds" + mock_node.credentials_optional = True + mock_node.input_default = {} # No credentials configured + + # Create a mock block with credentials field + mock_block = mocker.MagicMock() + mock_credentials_field_type = mocker.MagicMock() + mock_block.input_schema.get_credentials_fields.return_value = { + "credentials": mock_credentials_field_type + } + mock_node.block = mock_block + + # Create mock graph + mock_graph = mocker.MagicMock() + mock_graph.nodes = [mock_node] + + # Call the function + errors, nodes_to_skip = await _validate_node_input_credentials( + graph=mock_graph, + user_id="test-user-id", + nodes_input_masks=None, + ) + + # Node should be in nodes_to_skip, not in errors + assert mock_node.id in nodes_to_skip + assert mock_node.id not in errors + + +@pytest.mark.asyncio +async def test_validate_node_input_credentials_required_missing_creds_error( + mocker: MockerFixture, +): + """ + Test that _validate_node_input_credentials returns errors + for nodes with credentials_optional=False and missing credentials. + """ + from backend.executor.utils import _validate_node_input_credentials + + # Create a mock node with credentials_optional=False (required) + mock_node = mocker.MagicMock() + mock_node.id = "node-with-required-creds" + mock_node.credentials_optional = False + mock_node.input_default = {} # No credentials configured + + # Create a mock block with credentials field + mock_block = mocker.MagicMock() + mock_credentials_field_type = mocker.MagicMock() + mock_block.input_schema.get_credentials_fields.return_value = { + "credentials": mock_credentials_field_type + } + mock_node.block = mock_block + + # Create mock graph + mock_graph = mocker.MagicMock() + mock_graph.nodes = [mock_node] + + # Call the function + errors, nodes_to_skip = await _validate_node_input_credentials( + graph=mock_graph, + user_id="test-user-id", + nodes_input_masks=None, + ) + + # Node should be in errors, not in nodes_to_skip + assert mock_node.id in errors + assert "credentials" in errors[mock_node.id] + assert "required" in errors[mock_node.id]["credentials"].lower() + assert mock_node.id not in nodes_to_skip + + +@pytest.mark.asyncio +async def test_validate_graph_with_credentials_returns_nodes_to_skip( + mocker: MockerFixture, +): + """ + Test that validate_graph_with_credentials returns nodes_to_skip set + from _validate_node_input_credentials. + """ + from backend.executor.utils import validate_graph_with_credentials + + # Mock _validate_node_input_credentials to return specific values + mock_validate = mocker.patch( + "backend.executor.utils._validate_node_input_credentials" + ) + expected_errors = {"node1": {"field": "error"}} + expected_nodes_to_skip = {"node2", "node3"} + mock_validate.return_value = (expected_errors, expected_nodes_to_skip) + + # Mock GraphModel with validate_graph_get_errors method + mock_graph = mocker.MagicMock() + mock_graph.validate_graph_get_errors.return_value = {} + + # Call the function + errors, nodes_to_skip = await validate_graph_with_credentials( + graph=mock_graph, + user_id="test-user-id", + nodes_input_masks=None, + ) + + # Verify nodes_to_skip is passed through + assert nodes_to_skip == expected_nodes_to_skip + assert "node1" in errors + + +@pytest.mark.asyncio +async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture): + """ + Test that add_graph_execution properly passes nodes_to_skip + to the graph execution entry. + """ + from backend.data.execution import GraphExecutionWithNodes + from backend.executor.utils import add_graph_execution + + # Mock data + graph_id = "test-graph-id" + user_id = "test-user-id" + inputs = {"test_input": "test_value"} + graph_version = 1 + + # Mock the graph object + mock_graph = mocker.MagicMock() + mock_graph.version = graph_version + + # Starting nodes and masks + starting_nodes_input = [("node1", {"input1": "value1"})] + compiled_nodes_input_masks = {} + nodes_to_skip = {"skipped-node-1", "skipped-node-2"} + + # Mock the graph execution object + mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes) + mock_graph_exec.id = "execution-id-123" + mock_graph_exec.node_executions = [] + + # Track what's passed to to_graph_execution_entry + captured_kwargs = {} + + def capture_to_entry(**kwargs): + captured_kwargs.update(kwargs) + return mocker.MagicMock() + + mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry + + # Setup mocks + mock_validate = mocker.patch( + "backend.executor.utils.validate_and_construct_node_execution_input" + ) + mock_edb = mocker.patch("backend.executor.utils.execution_db") + mock_prisma = mocker.patch("backend.executor.utils.prisma") + mock_udb = mocker.patch("backend.executor.utils.user_db") + mock_gdb = mocker.patch("backend.executor.utils.graph_db") + mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue") + mock_get_event_bus = mocker.patch( + "backend.executor.utils.get_async_execution_event_bus" + ) + + # Setup returns - include nodes_to_skip in the tuple + mock_validate.return_value = ( + mock_graph, + starting_nodes_input, + compiled_nodes_input_masks, + nodes_to_skip, # This should be passed through + ) + mock_prisma.is_connected.return_value = True + mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec) + mock_edb.update_graph_execution_stats = mocker.AsyncMock( + return_value=mock_graph_exec + ) + mock_edb.update_node_execution_status_batch = mocker.AsyncMock() + + mock_user = mocker.MagicMock() + mock_user.timezone = "UTC" + mock_settings = mocker.MagicMock() + mock_settings.human_in_the_loop_safe_mode = True + + mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user) + mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings) + mock_get_queue.return_value = mocker.AsyncMock() + mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock()) + + # Call the function + await add_graph_execution( + graph_id=graph_id, + user_id=user_id, + inputs=inputs, + graph_version=graph_version, + ) + + # Verify nodes_to_skip was passed to to_graph_execution_entry + assert "nodes_to_skip" in captured_kwargs + assert captured_kwargs["nodes_to_skip"] == nodes_to_skip diff --git a/autogpt_platform/backend/backend/integrations/oauth/__init__.py b/autogpt_platform/backend/backend/integrations/oauth/__init__.py index 137b9eadfd..1ea7a23ee7 100644 --- a/autogpt_platform/backend/backend/integrations/oauth/__init__.py +++ b/autogpt_platform/backend/backend/integrations/oauth/__init__.py @@ -8,6 +8,7 @@ from .discord import DiscordOAuthHandler from .github import GitHubOAuthHandler from .google import GoogleOAuthHandler from .notion import NotionOAuthHandler +from .reddit import RedditOAuthHandler from .twitter import TwitterOAuthHandler if TYPE_CHECKING: @@ -20,6 +21,7 @@ _ORIGINAL_HANDLERS = [ GitHubOAuthHandler, GoogleOAuthHandler, NotionOAuthHandler, + RedditOAuthHandler, TwitterOAuthHandler, TodoistOAuthHandler, ] diff --git a/autogpt_platform/backend/backend/integrations/oauth/reddit.py b/autogpt_platform/backend/backend/integrations/oauth/reddit.py new file mode 100644 index 0000000000..a69e1e62c7 --- /dev/null +++ b/autogpt_platform/backend/backend/integrations/oauth/reddit.py @@ -0,0 +1,208 @@ +import time +import urllib.parse +from typing import ClassVar, Optional + +from pydantic import SecretStr + +from backend.data.model import OAuth2Credentials +from backend.integrations.oauth.base import BaseOAuthHandler +from backend.integrations.providers import ProviderName +from backend.util.request import Requests +from backend.util.settings import Settings + +settings = Settings() + + +class RedditOAuthHandler(BaseOAuthHandler): + """ + Reddit OAuth 2.0 handler. + + Based on the documentation at: + - https://github.com/reddit-archive/reddit/wiki/OAuth2 + + Notes: + - Reddit requires `duration=permanent` to get refresh tokens + - Access tokens expire after 1 hour (3600 seconds) + - Reddit requires HTTP Basic Auth for token requests + - Reddit requires a unique User-Agent header + """ + + PROVIDER_NAME = ProviderName.REDDIT + DEFAULT_SCOPES: ClassVar[list[str]] = [ + "identity", # Get username, verify auth + "read", # Access posts and comments + "submit", # Submit new posts and comments + "edit", # Edit own posts and comments + "history", # Access user's post history + "privatemessages", # Access inbox and send private messages + "flair", # Access and set flair on posts/subreddits + ] + + AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize" + TOKEN_URL = "https://www.reddit.com/api/v1/access_token" + USERNAME_URL = "https://oauth.reddit.com/api/v1/me" + REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token" + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + + def get_login_url( + self, scopes: list[str], state: str, code_challenge: Optional[str] + ) -> str: + """Generate Reddit OAuth 2.0 authorization URL""" + scopes = self.handle_default_scopes(scopes) + + params = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(scopes), + "state": state, + "duration": "permanent", # Required for refresh tokens + } + + return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_tokens( + self, code: str, scopes: list[str], code_verifier: Optional[str] + ) -> OAuth2Credentials: + """Exchange authorization code for access tokens""" + scopes = self.handle_default_scopes(scopes) + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": settings.config.reddit_user_agent, + } + + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.redirect_uri, + } + + # Reddit requires HTTP Basic Auth for token requests + auth = (self.client_id, self.client_secret) + + response = await Requests().post( + self.TOKEN_URL, headers=headers, data=data, auth=auth + ) + + if not response.ok: + error_text = response.text() + raise ValueError( + f"Reddit token exchange failed: {response.status} - {error_text}" + ) + + tokens = response.json() + + if "error" in tokens: + raise ValueError(f"Reddit OAuth error: {tokens.get('error')}") + + username = await self._get_username(tokens["access_token"]) + + return OAuth2Credentials( + provider=self.PROVIDER_NAME, + title=None, + username=username, + access_token=tokens["access_token"], + refresh_token=tokens.get("refresh_token"), + access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600), + refresh_token_expires_at=None, # Reddit refresh tokens don't expire + scopes=scopes, + ) + + async def _get_username(self, access_token: str) -> str: + """Get the username from the access token""" + headers = { + "Authorization": f"Bearer {access_token}", + "User-Agent": settings.config.reddit_user_agent, + } + + response = await Requests().get(self.USERNAME_URL, headers=headers) + + if not response.ok: + raise ValueError(f"Failed to get Reddit username: {response.status}") + + data = response.json() + return data.get("name", "unknown") + + async def _refresh_tokens( + self, credentials: OAuth2Credentials + ) -> OAuth2Credentials: + """Refresh access tokens using refresh token""" + if not credentials.refresh_token: + raise ValueError("No refresh token available") + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": settings.config.reddit_user_agent, + } + + data = { + "grant_type": "refresh_token", + "refresh_token": credentials.refresh_token.get_secret_value(), + } + + auth = (self.client_id, self.client_secret) + + response = await Requests().post( + self.TOKEN_URL, headers=headers, data=data, auth=auth + ) + + if not response.ok: + error_text = response.text() + raise ValueError( + f"Reddit token refresh failed: {response.status} - {error_text}" + ) + + tokens = response.json() + + if "error" in tokens: + raise ValueError(f"Reddit OAuth error: {tokens.get('error')}") + + username = await self._get_username(tokens["access_token"]) + + # Reddit may or may not return a new refresh token + new_refresh_token = tokens.get("refresh_token") + if new_refresh_token: + refresh_token: SecretStr | None = SecretStr(new_refresh_token) + elif credentials.refresh_token: + # Keep the existing refresh token + refresh_token = credentials.refresh_token + else: + refresh_token = None + + return OAuth2Credentials( + id=credentials.id, + provider=self.PROVIDER_NAME, + title=credentials.title, + username=username, + access_token=tokens["access_token"], + refresh_token=refresh_token, + access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600), + refresh_token_expires_at=None, + scopes=credentials.scopes, + ) + + async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool: + """Revoke the access token""" + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": settings.config.reddit_user_agent, + } + + data = { + "token": credentials.access_token.get_secret_value(), + "token_type_hint": "access_token", + } + + auth = (self.client_id, self.client_secret) + + response = await Requests().post( + self.REVOKE_URL, headers=headers, data=data, auth=auth + ) + + # Reddit returns 204 No Content on successful revocation + return response.ok diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index 0f17b1215c..7a51200eaf 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -264,7 +264,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): ) reddit_user_agent: str = Field( - default="AutoGPT:1.0 (by /u/autogpt)", + default="web:AutoGPT:v0.6.0 (by /u/autogpt)", description="The user agent for the Reddit API", ) diff --git a/autogpt_platform/backend/gen_prisma_types_stub.py b/autogpt_platform/backend/gen_prisma_types_stub.py new file mode 100644 index 0000000000..3f6073b2ff --- /dev/null +++ b/autogpt_platform/backend/gen_prisma_types_stub.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +Generate a lightweight stub for prisma/types.py that collapses all exported +symbols to Any. This prevents Pyright from spending time/budget on Prisma's +query DSL types while keeping runtime behavior unchanged. + +Usage: + poetry run gen-prisma-stub + +This script automatically finds the prisma package location and generates +the types.pyi stub file in the same directory as types.py. +""" + +from __future__ import annotations + +import ast +import importlib.util +import sys +from pathlib import Path +from typing import Iterable, Set + + +def _iter_assigned_names(target: ast.expr) -> Iterable[str]: + """Extract names from assignment targets (handles tuple unpacking).""" + if isinstance(target, ast.Name): + yield target.id + elif isinstance(target, (ast.Tuple, ast.List)): + for elt in target.elts: + yield from _iter_assigned_names(elt) + + +def _is_private(name: str) -> bool: + """Check if a name is private (starts with _ but not __).""" + return name.startswith("_") and not name.startswith("__") + + +def _is_safe_type_alias(node: ast.Assign) -> bool: + """Check if an assignment is a safe type alias that shouldn't be stubbed. + + Safe types are: + - Literal types (don't cause type budget issues) + - Simple type references (SortMode, SortOrder, etc.) + - TypeVar definitions + """ + if not node.value: + return False + + # Check if it's a Subscript (like Literal[...], Union[...], TypeVar[...]) + if isinstance(node.value, ast.Subscript): + # Get the base type name + if isinstance(node.value.value, ast.Name): + base_name = node.value.value.id + # Literal types are safe + if base_name == "Literal": + return True + # TypeVar is safe + if base_name == "TypeVar": + return True + elif isinstance(node.value.value, ast.Attribute): + # Handle typing_extensions.Literal etc. + if node.value.value.attr == "Literal": + return True + + # Check if it's a simple Name reference (like SortMode = _types.SortMode) + if isinstance(node.value, ast.Attribute): + return True + + # Check if it's a Call (like TypeVar(...)) + if isinstance(node.value, ast.Call): + if isinstance(node.value.func, ast.Name): + if node.value.func.id == "TypeVar": + return True + + return False + + +def collect_top_level_symbols( + tree: ast.Module, source_lines: list[str] +) -> tuple[Set[str], Set[str], list[str], Set[str]]: + """Collect all top-level symbols from an AST module. + + Returns: + Tuple of (class_names, function_names, safe_variable_sources, unsafe_variable_names) + safe_variable_sources contains the actual source code lines for safe variables + """ + classes: Set[str] = set() + functions: Set[str] = set() + safe_variable_sources: list[str] = [] + unsafe_variables: Set[str] = set() + + for node in tree.body: + if isinstance(node, ast.ClassDef): + if not _is_private(node.name): + classes.add(node.name) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + if not _is_private(node.name): + functions.add(node.name) + elif isinstance(node, ast.Assign): + is_safe = _is_safe_type_alias(node) + names = [] + for t in node.targets: + for n in _iter_assigned_names(t): + if not _is_private(n): + names.append(n) + if names: + if is_safe: + # Extract the source code for this assignment + start_line = node.lineno - 1 # 0-indexed + end_line = node.end_lineno if node.end_lineno else node.lineno + source = "\n".join(source_lines[start_line:end_line]) + safe_variable_sources.append(source) + else: + unsafe_variables.update(names) + elif isinstance(node, ast.AnnAssign) and node.target: + # Annotated assignments are always stubbed + for n in _iter_assigned_names(node.target): + if not _is_private(n): + unsafe_variables.add(n) + + return classes, functions, safe_variable_sources, unsafe_variables + + +def find_prisma_types_path() -> Path: + """Find the prisma types.py file in the installed package.""" + spec = importlib.util.find_spec("prisma") + if spec is None or spec.origin is None: + raise RuntimeError("Could not find prisma package. Is it installed?") + + prisma_dir = Path(spec.origin).parent + types_path = prisma_dir / "types.py" + + if not types_path.exists(): + raise RuntimeError(f"prisma/types.py not found at {types_path}") + + return types_path + + +def generate_stub(src_path: Path, stub_path: Path) -> int: + """Generate the .pyi stub file from the source types.py.""" + code = src_path.read_text(encoding="utf-8", errors="ignore") + source_lines = code.splitlines() + tree = ast.parse(code, filename=str(src_path)) + classes, functions, safe_variable_sources, unsafe_variables = ( + collect_top_level_symbols(tree, source_lines) + ) + + header = """\ +# -*- coding: utf-8 -*- +# Auto-generated stub file - DO NOT EDIT +# Generated by gen_prisma_types_stub.py +# +# This stub intentionally collapses complex Prisma query DSL types to Any. +# Prisma's generated types can explode Pyright's type inference budgets +# on large schemas. We collapse them to Any so the rest of the codebase +# can remain strongly typed while keeping runtime behavior unchanged. +# +# Safe types (Literal, TypeVar, simple references) are preserved from the +# original types.py to maintain proper type checking where possible. + +from __future__ import annotations +from typing import Any +from typing_extensions import Literal + +# Re-export commonly used typing constructs that may be imported from this module +from typing import TYPE_CHECKING, TypeVar, Generic, Union, Optional, List, Dict + +# Base type alias for stubbed Prisma types - allows any dict structure +_PrismaDict = dict[str, Any] + +""" + + lines = [header] + + # Include safe variable definitions (Literal types, TypeVars, etc.) + lines.append("# Safe type definitions preserved from original types.py") + for source in safe_variable_sources: + lines.append(source) + lines.append("") + + # Stub all classes and unsafe variables uniformly as dict[str, Any] aliases + # This allows: + # 1. Use in type annotations: x: SomeType + # 2. Constructor calls: SomeType(...) + # 3. Dict literal assignments: x: SomeType = {...} + lines.append( + "# Stubbed types (collapsed to dict[str, Any] to prevent type budget exhaustion)" + ) + all_stubbed = sorted(classes | unsafe_variables) + for name in all_stubbed: + lines.append(f"{name} = _PrismaDict") + + lines.append("") + + # Stub functions + for name in sorted(functions): + lines.append(f"def {name}(*args: Any, **kwargs: Any) -> Any: ...") + + lines.append("") + + stub_path.write_text("\n".join(lines), encoding="utf-8") + return ( + len(classes) + + len(functions) + + len(safe_variable_sources) + + len(unsafe_variables) + ) + + +def main() -> None: + """Main entry point.""" + try: + types_path = find_prisma_types_path() + stub_path = types_path.with_suffix(".pyi") + + print(f"Found prisma types.py at: {types_path}") + print(f"Generating stub at: {stub_path}") + + num_symbols = generate_stub(types_path, stub_path) + print(f"Generated {stub_path.name} with {num_symbols} Any-typed symbols") + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/autogpt_platform/backend/linter.py b/autogpt_platform/backend/linter.py index a86e6761f7..599aae4580 100644 --- a/autogpt_platform/backend/linter.py +++ b/autogpt_platform/backend/linter.py @@ -25,6 +25,9 @@ def run(*command: str) -> None: def lint(): + # Generate Prisma types stub before running pyright to prevent type budget exhaustion + run("gen-prisma-stub") + lint_step_args: list[list[str]] = [ ["ruff", "check", *TARGET_DIRS, "--exit-zero"], ["ruff", "format", "--diff", "--check", LIBS_DIR], @@ -49,4 +52,6 @@ def format(): run("ruff", "format", LIBS_DIR) run("isort", "--profile", "black", BACKEND_DIR) run("black", BACKEND_DIR) + # Generate Prisma types stub before running pyright to prevent type budget exhaustion + run("gen-prisma-stub") run("pyright", *TARGET_DIRS) diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index e8b8fd0ba5..21bf15e776 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -117,6 +117,7 @@ lint = "linter:lint" test = "run_tests:test" load-store-agents = "test.load_store_agents:run" export-api-schema = "backend.cli.generate_openapi_json:main" +gen-prisma-stub = "gen_prisma_types_stub:main" oauth-tool = "backend.cli.oauth_tool:cli" [tool.isort] @@ -134,6 +135,9 @@ ignore_patterns = [] [tool.pytest.ini_options] asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "session" +# Disable syrupy plugin to avoid conflict with pytest-snapshot +# Both provide --snapshot-update argument causing ArgumentError +addopts = "-p no:syrupy" filterwarnings = [ "ignore:'audioop' is deprecated:DeprecationWarning:discord.player", "ignore:invalid escape sequence:DeprecationWarning:tweepy.api", diff --git a/autogpt_platform/backend/snapshots/grph_single b/autogpt_platform/backend/snapshots/grph_single index 7ba26f6171..7ce8695e6b 100644 --- a/autogpt_platform/backend/snapshots/grph_single +++ b/autogpt_platform/backend/snapshots/grph_single @@ -2,6 +2,7 @@ "created_at": "2025-09-04T13:37:00", "credentials_input_schema": { "properties": {}, + "required": [], "title": "TestGraphCredentialsInputSchema", "type": "object" }, diff --git a/autogpt_platform/backend/snapshots/grphs_all b/autogpt_platform/backend/snapshots/grphs_all index d54df2bc18..f69b45a6de 100644 --- a/autogpt_platform/backend/snapshots/grphs_all +++ b/autogpt_platform/backend/snapshots/grphs_all @@ -2,6 +2,7 @@ { "credentials_input_schema": { "properties": {}, + "required": [], "title": "TestGraphCredentialsInputSchema", "type": "object" }, diff --git a/autogpt_platform/backend/snapshots/lib_agts_search b/autogpt_platform/backend/snapshots/lib_agts_search index d1feb7d16d..c8e3cc73a6 100644 --- a/autogpt_platform/backend/snapshots/lib_agts_search +++ b/autogpt_platform/backend/snapshots/lib_agts_search @@ -4,6 +4,7 @@ "id": "test-agent-1", "graph_id": "test-agent-1", "graph_version": 1, + "owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "image_url": null, "creator_name": "Test Creator", "creator_image_url": "", @@ -41,6 +42,7 @@ "id": "test-agent-2", "graph_id": "test-agent-2", "graph_version": 1, + "owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "image_url": null, "creator_name": "Test Creator", "creator_image_url": "", diff --git a/autogpt_platform/backend/snapshots/sub_success b/autogpt_platform/backend/snapshots/sub_success index 13e2ec570d..268d577745 100644 --- a/autogpt_platform/backend/snapshots/sub_success +++ b/autogpt_platform/backend/snapshots/sub_success @@ -1,6 +1,7 @@ { "submissions": [ { + "listing_id": "test-listing-id", "agent_id": "test-agent-id", "agent_version": 1, "name": "Test Agent", diff --git a/autogpt_platform/docker-compose.platform.yml b/autogpt_platform/docker-compose.platform.yml index b2df626029..de6ecfd612 100644 --- a/autogpt_platform/docker-compose.platform.yml +++ b/autogpt_platform/docker-compose.platform.yml @@ -37,7 +37,7 @@ services: context: ../ dockerfile: autogpt_platform/backend/Dockerfile target: migrate - command: ["sh", "-c", "poetry run prisma generate && poetry run prisma migrate deploy"] + command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"] develop: watch: - path: ./ diff --git a/autogpt_platform/frontend/public/integrations/webshare_proxy.png b/autogpt_platform/frontend/public/integrations/webshare_proxy.png new file mode 100644 index 0000000000..2b07ef8415 Binary files /dev/null and b/autogpt_platform/frontend/public/integrations/webshare_proxy.png differ diff --git a/autogpt_platform/frontend/public/integrations/wordpress.png b/autogpt_platform/frontend/public/integrations/wordpress.png new file mode 100644 index 0000000000..b8ba8bd3ff Binary files /dev/null and b/autogpt_platform/frontend/public/integrations/wordpress.png differ diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog.tsx index 431feeaade..bd08aa8ee0 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog.tsx @@ -66,6 +66,7 @@ export const RunInputDialog = ({ formContext={{ showHandles: false, size: "large", + showOptionalToggle: false, }} /> diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts index a71ad0bd07..ddd77bae48 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts @@ -66,7 +66,7 @@ export const useRunInputDialog = ({ if (isCredentialFieldSchema(fieldSchema)) { dynamicUiSchema[fieldName] = { ...dynamicUiSchema[fieldName], - "ui:field": "credentials", + "ui:field": "custom/credential_field", }; } }); @@ -76,12 +76,18 @@ export const useRunInputDialog = ({ }, [credentialsSchema]); const handleManualRun = async () => { + // Filter out incomplete credentials (those without a valid id) + // RJSF auto-populates const values (provider, type) but not id field + const validCredentials = Object.fromEntries( + Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id), + ); + await executeGraph({ graphId: flowID ?? "", graphVersion: flowVersion || null, data: { inputs: inputValues, - credentials_inputs: credentialValues, + credentials_inputs: validCredentials, source: "builder", }, }); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx index 5943986d30..e13aa37a31 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx @@ -68,7 +68,10 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
- + {beautifyString(title).replace("Block", "").trim()}
diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx index c505282e7b..31b89315d6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx @@ -151,7 +151,7 @@ export const NodeDataViewer: FC = ({
- {outputItems.length > 0 && ( + {outputItems.length > 1 && ( ({ value: item.value, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts index 39384485f5..46032a67ea 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts @@ -89,6 +89,18 @@ export function extractOptions( // get display type and color for schema types [need for type display next to field name] export const getTypeDisplayInfo = (schema: any) => { + if ( + schema?.type === "array" && + "format" in schema && + schema.format === "table" + ) { + return { + displayType: "table", + colorClass: "!text-indigo-500", + hexColor: "#6366f1", + }; + } + if (schema?.type === "string" && schema?.format) { const formatMap: Record< string, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/uiSchema.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/uiSchema.ts index ad1fab7c95..065e697828 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/uiSchema.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/uiSchema.ts @@ -1,6 +1,6 @@ export const uiSchema = { credentials: { - "ui:field": "credentials", + "ui:field": "custom/credential_field", provider: { "ui:widget": "hidden" }, type: { "ui:widget": "hidden" }, id: { "ui:autofocus": true }, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts index 96478c5b6f..c151f90faa 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts @@ -68,6 +68,9 @@ type NodeStore = { clearAllNodeErrors: () => void; // Add this syncHardcodedValuesWithHandleIds: (nodeId: string) => void; + + // Credentials optional helpers + setCredentialsOptional: (nodeId: string, optional: boolean) => void; }; export const useNodeStore = create((set, get) => ({ @@ -226,6 +229,9 @@ export const useNodeStore = create((set, get) => ({ ...(node.data.metadata?.customized_name !== undefined && { customized_name: node.data.metadata.customized_name, }), + ...(node.data.metadata?.credentials_optional !== undefined && { + credentials_optional: node.data.metadata.credentials_optional, + }), }, }; }, @@ -342,4 +348,30 @@ export const useNodeStore = create((set, get) => ({ })); } }, + + setCredentialsOptional: (nodeId: string, optional: boolean) => { + set((state) => ({ + nodes: state.nodes.map((n) => + n.id === nodeId + ? { + ...n, + data: { + ...n.data, + metadata: { + ...n.data.metadata, + credentials_optional: optional, + }, + }, + } + : n, + ), + })); + + const newState = { + nodes: get().nodes, + edges: useEdgeStore.getState().edges, + }; + + useHistoryStore.getState().pushState(newState); + }, })); diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs.tsx index 60d61fab57..a0f9376aa2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs.tsx @@ -34,7 +34,9 @@ type Props = { onSelectCredentials: (newValue?: CredentialsMetaInput) => void; onLoaded?: (loaded: boolean) => void; readOnly?: boolean; + isOptional?: boolean; showTitle?: boolean; + variant?: "default" | "node"; }; export function CredentialsInput({ @@ -45,7 +47,9 @@ export function CredentialsInput({ siblingInputs, onLoaded, readOnly = false, + isOptional = false, showTitle = true, + variant = "default", }: Props) { const hookData = useCredentialsInput({ schema, @@ -54,6 +58,7 @@ export function CredentialsInput({ siblingInputs, onLoaded, readOnly, + isOptional, }); if (!isLoaded(hookData)) { @@ -94,7 +99,14 @@ export function CredentialsInput({
{showTitle && (
- {displayName} credentials + + {displayName} credentials + {isOptional && ( + + (optional) + + )} + {schema.description && ( )} @@ -103,14 +115,17 @@ export function CredentialsInput({ {hasCredentialsToShow ? ( <> - {credentialsToShow.length > 1 && !readOnly ? ( + {(credentialsToShow.length > 1 || isOptional) && !readOnly ? ( onSelectCredential(undefined)} readOnly={readOnly} + allowNone={isOptional} + variant={variant} /> ) : (
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx index 21ec1200e4..2d0358aacb 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx @@ -30,6 +30,8 @@ type CredentialRowProps = { readOnly?: boolean; showCaret?: boolean; asSelectTrigger?: boolean; + /** When "node", applies compact styling for node context */ + variant?: "default" | "node"; }; export function CredentialRow({ @@ -41,14 +43,22 @@ export function CredentialRow({ readOnly = false, showCaret = false, asSelectTrigger = false, + variant = "default", }: CredentialRowProps) { const ProviderIcon = providerIcons[provider] || fallbackIcon; + const isNodeVariant = variant === "node"; return (
-
+
{getCredentialDisplayName(credential, displayName)} - - {"*".repeat(MASKED_KEY_LENGTH)} - + {!(asSelectTrigger && isNodeVariant) && ( + + {"*".repeat(MASKED_KEY_LENGTH)} + + )}
{showCaret && !asSelectTrigger && ( diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx index 7adfa5772b..6e1ec2afb1 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx @@ -7,6 +7,7 @@ import { } from "@/components/__legacy__/ui/select"; import { Text } from "@/components/atoms/Text/Text"; import { CredentialsMetaInput } from "@/lib/autogpt-server-api/types"; +import { cn } from "@/lib/utils"; import { useEffect } from "react"; import { getCredentialDisplayName } from "../../helpers"; import { CredentialRow } from "../CredentialRow/CredentialRow"; @@ -23,7 +24,11 @@ interface Props { displayName: string; selectedCredentials?: CredentialsMetaInput; onSelectCredential: (credentialId: string) => void; + onClearCredential?: () => void; readOnly?: boolean; + allowNone?: boolean; + /** When "node", applies compact styling for node context */ + variant?: "default" | "node"; } export function CredentialsSelect({ @@ -32,22 +37,38 @@ export function CredentialsSelect({ displayName, selectedCredentials, onSelectCredential, + onClearCredential, readOnly = false, + allowNone = true, + variant = "default", }: Props) { - // Auto-select first credential if none is selected + // Auto-select first credential if none is selected (only if allowNone is false) useEffect(() => { - if (!selectedCredentials && credentials.length > 0) { + if (!allowNone && !selectedCredentials && credentials.length > 0) { onSelectCredential(credentials[0].id); } - }, [selectedCredentials, credentials, onSelectCredential]); + }, [allowNone, selectedCredentials, credentials, onSelectCredential]); + + const handleValueChange = (value: string) => { + if (value === "__none__") { + onClearCredential?.(); + } else { + onSelectCredential(value); + } + }; return (
void; + onClose: () => void; } export const useEditAgentForm = ({ submission, onSuccess, + onClose, }: useEditAgentFormProps) => { const editAgentSchema = z.object({ title: z @@ -45,7 +47,7 @@ export const useEditAgentForm = ({ changes_summary: z .string() .min(1, "Changes summary is required") - .max(200, "Changes summary must be less than 200 characters"), + .max(500, "Changes summary must be less than 500 characters"), agentOutputDemo: z .string() .refine(validateYouTubeUrl, "Please enter a valid YouTube URL"), @@ -54,19 +56,11 @@ export const useEditAgentForm = ({ type EditAgentFormData = z.infer; const [images, setImages] = React.useState( - submission.image_urls || [], + Array.from(new Set(submission.image_urls || [])), // Remove duplicates ); const [isSubmitting, setIsSubmitting] = React.useState(false); - const { mutateAsync: editSubmission } = usePutV2EditStoreSubmission({ - mutation: { - onSuccess: () => { - queryClient.invalidateQueries({ - queryKey: getGetV2ListMySubmissionsQueryKey(), - }); - }, - }, - }); + const { mutateAsync: editSubmission } = usePutV2EditStoreSubmission(); const queryClient = useQueryClient(); const { toast } = useToast(); @@ -132,7 +126,20 @@ export const useEditAgentForm = ({ // Extract the StoreSubmission from the response if (response.status === 200 && response.data) { + toast({ + title: "Agent Updated", + description: "Your agent submission has been updated successfully.", + duration: 3000, + variant: "default", + }); + + queryClient.invalidateQueries({ + queryKey: getGetV2ListMySubmissionsQueryKey(), + }); + + // Call onSuccess and explicitly close the modal onSuccess(response.data); + onClose(); } else { throw new Error("Failed to update submission"); } diff --git a/autogpt_platform/frontend/src/components/contextual/PublishAgentModal/components/AgentInfoStep/useAgentInfoStep.ts b/autogpt_platform/frontend/src/components/contextual/PublishAgentModal/components/AgentInfoStep/useAgentInfoStep.ts index f3dcfa1f21..448d65b195 100644 --- a/autogpt_platform/frontend/src/components/contextual/PublishAgentModal/components/AgentInfoStep/useAgentInfoStep.ts +++ b/autogpt_platform/frontend/src/components/contextual/PublishAgentModal/components/AgentInfoStep/useAgentInfoStep.ts @@ -1,10 +1,8 @@ import { useEffect, useCallback, useState } from "react"; import { useForm } from "react-hook-form"; import { zodResolver } from "@hookform/resolvers/zod"; -import { useQueryClient } from "@tanstack/react-query"; import { useToast } from "@/components/molecules/Toast/use-toast"; import { useBackendAPI } from "@/lib/autogpt-server-api/context"; -import { getGetV2ListMySubmissionsQueryKey } from "@/app/api/__generated__/endpoints/store/store"; import * as Sentry from "@sentry/nextjs"; import { PublishAgentFormData, @@ -33,7 +31,6 @@ export function useAgentInfoStep({ const [images, setImages] = useState([]); const [isSubmitting, setIsSubmitting] = useState(false); - const queryClient = useQueryClient(); const { toast } = useToast(); const api = useBackendAPI(); @@ -54,23 +51,26 @@ export function useAgentInfoStep({ }); useEffect(() => { - if (initialData) { + if (initialData?.agent_id) { setAgentId(initialData.agent_id); - const initialImages = [ - ...(initialData?.thumbnailSrc ? [initialData.thumbnailSrc] : []), - ...(initialData.additionalImages || []), - ]; - setImages(initialImages); - - // Update form with initial data + setImages( + Array.from( + new Set([ + ...(initialData?.thumbnailSrc ? [initialData.thumbnailSrc] : []), + ...(initialData.additionalImages || []), + ]), + ), + ); form.reset({ - changesSummary: initialData.changesSummary || "", + changesSummary: isMarketplaceUpdate + ? "" + : initialData.changesSummary || "", title: initialData.title, subheader: initialData.subheader, slug: initialData.slug.toLocaleLowerCase().trim(), youtubeLink: initialData.youtubeLink, category: initialData.category, - description: initialData.description, + description: isMarketplaceUpdate ? "" : initialData.description, recommendedScheduleCron: initialData.recommendedScheduleCron || "", instructions: initialData.instructions || "", agentOutputDemo: initialData.agentOutputDemo || "", @@ -78,6 +78,13 @@ export function useAgentInfoStep({ } }, [initialData, form]); + // Ensure agentId is set from selectedAgentId if initialData doesn't have it + useEffect(() => { + if (selectedAgentId && !agentId) { + setAgentId(selectedAgentId); + } + }, [selectedAgentId, agentId]); + const handleImagesChange = useCallback((newImages: string[]) => { setImages(newImages); }, []); @@ -92,6 +99,16 @@ export function useAgentInfoStep({ return; } + // Validate that an agent is selected before submission + if (!selectedAgentId || !selectedAgentVersion) { + toast({ + title: "Agent Selection Required", + description: "Please select an agent before submitting to the store.", + variant: "destructive", + }); + return; + } + const categories = data.category ? [data.category] : []; const filteredCategories = categories.filter(Boolean); @@ -106,18 +123,14 @@ export function useAgentInfoStep({ image_urls: images, video_url: data.youtubeLink || "", agent_output_demo_url: data.agentOutputDemo || "", - agent_id: selectedAgentId || "", - agent_version: selectedAgentVersion || 0, + agent_id: selectedAgentId, + agent_version: selectedAgentVersion, slug: (data.slug || "").replace(/\s+/g, "-"), categories: filteredCategories, recommended_schedule_cron: data.recommendedScheduleCron || null, changes_summary: data.changesSummary || null, } as any); - await queryClient.invalidateQueries({ - queryKey: getGetV2ListMySubmissionsQueryKey(), - }); - onSuccess(response); } catch (error) { Sentry.captureException(error); @@ -139,12 +152,7 @@ export function useAgentInfoStep({ agentId, images, isSubmitting, - initialImages: initialData - ? [ - ...(initialData?.thumbnailSrc ? [initialData.thumbnailSrc] : []), - ...(initialData.additionalImages || []), - ] - : [], + initialImages: images, initialSelectedImage: initialData?.thumbnailSrc || null, handleImagesChange, handleSubmit: form.handleSubmit(handleFormSubmit), diff --git a/autogpt_platform/frontend/src/components/contextual/PublishAgentModal/usePublishAgentModal.ts b/autogpt_platform/frontend/src/components/contextual/PublishAgentModal/usePublishAgentModal.ts index 0f8a819c6e..69bbb6c866 100644 --- a/autogpt_platform/frontend/src/components/contextual/PublishAgentModal/usePublishAgentModal.ts +++ b/autogpt_platform/frontend/src/components/contextual/PublishAgentModal/usePublishAgentModal.ts @@ -6,9 +6,11 @@ import { emptyModalState } from "./helpers"; import { useGetV2GetMyAgents, useGetV2ListMySubmissions, + getGetV2ListMySubmissionsQueryKey, } from "@/app/api/__generated__/endpoints/store/store"; import { okData } from "@/app/api/helpers"; import type { MyAgent } from "@/app/api/__generated__/models/myAgent"; +import { useQueryClient } from "@tanstack/react-query"; const defaultTargetState: PublishState = { isOpen: false, @@ -65,6 +67,7 @@ export function usePublishAgentModal({ >(preSelectedAgentVersion || null); const router = useRouter(); + const queryClient = useQueryClient(); // Fetch agent data for pre-populating form when agent is pre-selected const { data: myAgents } = useGetV2GetMyAgents(); @@ -77,14 +80,18 @@ export function usePublishAgentModal({ } }, [targetState]); - // Reset internal state when modal opens + // Reset internal state when modal opens (only on initial open, not on every targetState change) + const [hasOpened, setHasOpened] = useState(false); useEffect(() => { if (!targetState) return; - if (targetState.isOpen) { + if (targetState.isOpen && !hasOpened) { setSelectedAgent(null); setSelectedAgentId(preSelectedAgentId || null); setSelectedAgentVersion(preSelectedAgentVersion || null); setInitialData(emptyModalState); + setHasOpened(true); + } else if (!targetState.isOpen && hasOpened) { + setHasOpened(false); } }, [targetState, preSelectedAgentId, preSelectedAgentVersion]); @@ -172,6 +179,11 @@ export function usePublishAgentModal({ setSelectedAgentVersion(null); setInitialData(emptyModalState); + // Invalidate submissions query to refresh the data after modal closes + queryClient.invalidateQueries({ + queryKey: getGetV2ListMySubmissionsQueryKey(), + }); + // Update parent with clean closed state const newState = { isOpen: false, diff --git a/autogpt_platform/frontend/src/components/molecules/Table/Table.stories.tsx b/autogpt_platform/frontend/src/components/molecules/Table/Table.stories.tsx new file mode 100644 index 0000000000..6dfb0b378f --- /dev/null +++ b/autogpt_platform/frontend/src/components/molecules/Table/Table.stories.tsx @@ -0,0 +1,116 @@ +import type { Meta, StoryObj } from "@storybook/nextjs"; +import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip"; +import { Table } from "./Table"; + +const meta = { + title: "Molecules/Table", + component: Table, + decorators: [ + (Story) => ( + + + + ), + ], + parameters: { + layout: "centered", + }, + tags: ["autodocs"], + argTypes: { + allowAddRow: { + control: "boolean", + description: "Whether to show the Add row button", + }, + allowDeleteRow: { + control: "boolean", + description: "Whether to show delete buttons for each row", + }, + readOnly: { + control: "boolean", + description: + "Whether the table is read-only (renders text instead of inputs)", + }, + addRowLabel: { + control: "text", + description: "Label for the Add row button", + }, + }, +} satisfies Meta; + +export default meta; +type Story = StoryObj; + +export const Default: Story = { + args: { + columns: ["name", "email", "role"], + allowAddRow: true, + allowDeleteRow: true, + }, +}; + +export const WithDefaultValues: Story = { + args: { + columns: ["name", "email", "role"], + defaultValues: [ + { name: "John Doe", email: "john@example.com", role: "Admin" }, + { name: "Jane Smith", email: "jane@example.com", role: "User" }, + { name: "Bob Wilson", email: "bob@example.com", role: "Editor" }, + ], + allowAddRow: true, + allowDeleteRow: true, + }, +}; + +export const ReadOnly: Story = { + args: { + columns: ["name", "email"], + defaultValues: [ + { name: "John Doe", email: "john@example.com" }, + { name: "Jane Smith", email: "jane@example.com" }, + ], + readOnly: true, + }, +}; + +export const NoAddOrDelete: Story = { + args: { + columns: ["name", "email"], + defaultValues: [ + { name: "John Doe", email: "john@example.com" }, + { name: "Jane Smith", email: "jane@example.com" }, + ], + allowAddRow: false, + allowDeleteRow: false, + }, +}; + +export const SingleColumn: Story = { + args: { + columns: ["item"], + allowAddRow: true, + allowDeleteRow: true, + addRowLabel: "Add item", + }, +}; + +export const CustomAddLabel: Story = { + args: { + columns: ["key", "value"], + allowAddRow: true, + allowDeleteRow: true, + addRowLabel: "Add new entry", + }, +}; + +export const KeyValuePairs: Story = { + args: { + columns: ["key", "value"], + defaultValues: [ + { key: "API_KEY", value: "sk-..." }, + { key: "DATABASE_URL", value: "postgres://..." }, + ], + allowAddRow: true, + allowDeleteRow: true, + addRowLabel: "Add variable", + }, +}; diff --git a/autogpt_platform/frontend/src/components/molecules/Table/Table.tsx b/autogpt_platform/frontend/src/components/molecules/Table/Table.tsx new file mode 100644 index 0000000000..a09a8344a5 --- /dev/null +++ b/autogpt_platform/frontend/src/components/molecules/Table/Table.tsx @@ -0,0 +1,133 @@ +import * as React from "react"; +import { + Table as BaseTable, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/__legacy__/ui/table"; +import { Button } from "@/components/atoms/Button/Button"; +import { Input } from "@/components/atoms/Input/Input"; +import { Text } from "@/components/atoms/Text/Text"; +import { Plus, Trash2 } from "lucide-react"; +import { cn } from "@/lib/utils"; +import { useTable, RowData } from "./useTable"; +import { formatColumnTitle, formatPlaceholder } from "./helpers"; + +export interface TableProps { + columns: string[]; + defaultValues?: RowData[]; + onChange?: (rows: RowData[]) => void; + allowAddRow?: boolean; + allowDeleteRow?: boolean; + addRowLabel?: string; + className?: string; + readOnly?: boolean; +} + +export function Table({ + columns, + defaultValues, + onChange, + allowAddRow = true, + allowDeleteRow = true, + addRowLabel = "Add row", + className, + readOnly = false, +}: TableProps) { + const { rows, handleAddRow, handleDeleteRow, handleCellChange } = useTable({ + columns, + defaultValues, + onChange, + }); + + const showDeleteColumn = allowDeleteRow && !readOnly; + const showAddButton = allowAddRow && !readOnly; + + return ( +
+
+ + + + {columns.map((column) => ( + + {formatColumnTitle(column)} + + ))} + {showDeleteColumn && } + + + + {rows.map((row, rowIndex) => ( + + {columns.map((column) => ( + + {readOnly ? ( + + {row[column] || "-"} + + ) : ( + + handleCellChange(rowIndex, column, e.target.value) + } + placeholder={formatPlaceholder(column)} + size="small" + wrapperClassName="mb-0" + /> + )} + + ))} + {showDeleteColumn && ( + + + + )} + + ))} + {showAddButton && ( + + + + + + )} + + +
+
+ ); +} + +export { type RowData } from "./useTable"; diff --git a/autogpt_platform/frontend/src/components/molecules/Table/helpers.ts b/autogpt_platform/frontend/src/components/molecules/Table/helpers.ts new file mode 100644 index 0000000000..3ea116095a --- /dev/null +++ b/autogpt_platform/frontend/src/components/molecules/Table/helpers.ts @@ -0,0 +1,7 @@ +export const formatColumnTitle = (key: string): string => { + return key.charAt(0).toUpperCase() + key.slice(1); +}; + +export const formatPlaceholder = (key: string): string => { + return `Enter ${key.toLowerCase()}`; +}; diff --git a/autogpt_platform/frontend/src/components/molecules/Table/useTable.ts b/autogpt_platform/frontend/src/components/molecules/Table/useTable.ts new file mode 100644 index 0000000000..085c18aa74 --- /dev/null +++ b/autogpt_platform/frontend/src/components/molecules/Table/useTable.ts @@ -0,0 +1,81 @@ +import { useState, useEffect } from "react"; + +export type RowData = Record; + +interface UseTableOptions { + columns: string[]; + defaultValues?: RowData[]; + onChange?: (rows: RowData[]) => void; +} + +export function useTable({ + columns, + defaultValues, + onChange, +}: UseTableOptions) { + const createEmptyRow = (): RowData => { + const emptyRow: RowData = {}; + columns.forEach((column) => { + emptyRow[column] = ""; + }); + return emptyRow; + }; + + const [rows, setRows] = useState(() => { + if (defaultValues && defaultValues.length > 0) { + return defaultValues; + } + return []; + }); + + useEffect(() => { + if (defaultValues !== undefined) { + setRows(defaultValues); + } + }, [defaultValues]); + + const updateRows = (newRows: RowData[]) => { + setRows(newRows); + onChange?.(newRows); + }; + + const handleAddRow = () => { + const newRows = [...rows, createEmptyRow()]; + updateRows(newRows); + }; + + const handleDeleteRow = (rowIndex: number) => { + const newRows = rows.filter((_, index) => index !== rowIndex); + updateRows(newRows); + }; + + const handleCellChange = ( + rowIndex: number, + columnKey: string, + value: string, + ) => { + const newRows = rows.map((row, index) => { + if (index === rowIndex) { + return { + ...row, + [columnKey]: value, + }; + } + return row; + }); + updateRows(newRows); + }; + + const clearAll = () => { + updateRows([]); + }; + + return { + rows, + handleAddRow, + handleDeleteRow, + handleCellChange, + clearAll, + createEmptyRow, + }; +} diff --git a/autogpt_platform/frontend/src/components/molecules/Toast/styles.module.css b/autogpt_platform/frontend/src/components/molecules/Toast/styles.module.css index ec8271958f..d1daaaa88b 100644 --- a/autogpt_platform/frontend/src/components/molecules/Toast/styles.module.css +++ b/autogpt_platform/frontend/src/components/molecules/Toast/styles.module.css @@ -37,11 +37,3 @@ html body .toastDescription { font-size: 0.75rem !important; line-height: 1.25rem !important; } - -/* Position close button on the right */ -/* stylelint-disable-next-line selector-pseudo-class-no-unknown */ -#root [data-sonner-toast] [data-close-button="true"] { - left: unset !important; - right: -18px !important; - top: -3px !important; -} diff --git a/autogpt_platform/frontend/src/components/molecules/Toast/toaster.tsx b/autogpt_platform/frontend/src/components/molecules/Toast/toaster.tsx index adb88831af..80b228f03b 100644 --- a/autogpt_platform/frontend/src/components/molecules/Toast/toaster.tsx +++ b/autogpt_platform/frontend/src/components/molecules/Toast/toaster.tsx @@ -1,6 +1,5 @@ "use client"; -import { CheckCircle, Info, Warning, XCircle } from "@phosphor-icons/react"; import { Toaster as SonnerToaster } from "sonner"; import styles from "./styles.module.css"; @@ -23,10 +22,10 @@ export function Toaster() { }} className="custom__toast" icons={{ - success: , - error: , - warning: , - info: , + success: null, + error: null, + warning: null, + info: null, }} /> ); diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/FormRenderer.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/FormRenderer.tsx index f784b64516..da0e3d6683 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/FormRenderer.tsx +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/FormRenderer.tsx @@ -30,6 +30,8 @@ export const FormRenderer = ({ return generateUiSchemaForCustomFields(preprocessedSchema, uiSchema); }, [preprocessedSchema, uiSchema]); + console.log("preprocessedSchema", preprocessedSchema); + return (
{ const { registry, schema } = props; const { fields } = registry; const { SchemaField: _SchemaField } = fields; const { nodeId } = registry.formContext; - const { isInputConnected } = useEdgeStore(); - - const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions); - - const Widget = getWidget({ type: "string" }, "select", registry.widgets); - const { handleOptionChange, enumOptions, @@ -26,6 +21,15 @@ export const AnyOfField = (props: FieldProps) => { field_id, } = useAnyOfField(props); + const parentCustomFieldId = findCustomFieldId(schema); + if (parentCustomFieldId) { + return null; + } + + const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions); + + const Widget = getWidget({ type: "string" }, "select", registry.widgets); + const handleId = getHandleId({ uiOptions, id: field_id + ANY_OF_FLAG, @@ -40,12 +44,21 @@ export const AnyOfField = (props: FieldProps) => { const isHandleConnected = isInputConnected(nodeId, handleId); + // Now anyOf can render - custom fields if the option schema matches a custom field + const optionCustomFieldId = optionSchema + ? findCustomFieldId(optionSchema) + : null; + + const optionUiSchema = optionCustomFieldId + ? { ...updatedUiSchema, "ui:field": optionCustomFieldId } + : updatedUiSchema; + const optionsSchemaField = (optionSchema && optionSchema.type !== "null" && ( <_SchemaField {...props} schema={optionSchema} - uiSchema={updatedUiSchema} + uiSchema={optionUiSchema} /> )) || null; diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/standard/widgets/TextInput/TextInputExpanderModal.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/standard/widgets/TextInput/TextInputExpanderModal.tsx index 5b19874bfb..a28b460ea5 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/standard/widgets/TextInput/TextInputExpanderModal.tsx +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/standard/widgets/TextInput/TextInputExpanderModal.tsx @@ -17,6 +17,7 @@ interface InputExpanderModalProps { defaultValue: string; description?: string; placeholder?: string; + inputType?: "text" | "json"; } export const InputExpanderModal: FC = ({ @@ -27,6 +28,7 @@ export const InputExpanderModal: FC = ({ defaultValue, description, placeholder, + inputType = "text", }) => { const [tempValue, setTempValue] = useState(defaultValue); const [isCopied, setIsCopied] = useState(false); @@ -78,7 +80,10 @@ export const InputExpanderModal: FC = ({ hideLabel id="input-expander-modal" value={tempValue} - className="!min-h-[300px] rounded-2xlarge" + className={cn( + "!min-h-[300px] rounded-2xlarge", + inputType === "json" && "font-mono text-sm", + )} onChange={(e) => setTempValue(e.target.value)} placeholder={placeholder || "Enter text..."} autoFocus diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/CredentialField/CredentialField.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/CredentialField/CredentialField.tsx index f814fba93f..707b48f9d9 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/CredentialField/CredentialField.tsx +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/CredentialField/CredentialField.tsx @@ -8,19 +8,34 @@ import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/component import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore"; import { useShallow } from "zustand/react/shallow"; import { CredentialFieldTitle } from "./components/CredentialFieldTitle"; +import { Switch } from "@/components/atoms/Switch/Switch"; export const CredentialsField = (props: FieldProps) => { - const { formData, onChange, schema, registry, fieldPathId } = props; + const { formData, onChange, schema, registry, fieldPathId, required } = props; const formContext = registry.formContext; const uiOptions = getUiOptions(props.uiSchema); const nodeId = formContext?.nodeId; - // Get sibling inputs (hardcoded values) from the node store - const hardcodedValues = useNodeStore( - useShallow((state) => (nodeId ? state.getHardCodedValues(nodeId) : {})), + // Get sibling inputs (hardcoded values) and credentials optional state from the node store + // Note: We select the node data directly instead of using getter functions to avoid + // creating new object references that would cause infinite re-render loops with useShallow + const { node, setCredentialsOptional } = useNodeStore( + useShallow((state) => ({ + node: nodeId ? state.nodes.find((n) => n.id === nodeId) : undefined, + setCredentialsOptional: state.setCredentialsOptional, + })), ); + const hardcodedValues = useMemo( + () => node?.data?.hardcodedValues || {}, + [node?.data?.hardcodedValues], + ); + const credentialsOptional = useMemo(() => { + const value = node?.data?.metadata?.credentials_optional; + return typeof value === "boolean" ? value : false; + }, [node?.data?.metadata?.credentials_optional]); + const handleChange = (newValue: any) => { onChange(newValue, fieldPathId?.path); }; @@ -52,6 +67,10 @@ export const CredentialsField = (props: FieldProps) => { [formData?.id, formData?.provider, formData?.title, formData?.type], ); + // In builder canvas (nodeId exists): show star based on credentialsOptional toggle + // In run dialogs (no nodeId): show star based on schema's required array + const isRequired = nodeId ? !credentialsOptional : required; + return (
{ registry={registry} uiOptions={uiOptions} schema={schema} + required={isRequired} /> { siblingInputs={hardcodedValues} showTitle={false} readOnly={formContext?.readOnly} + isOptional={!isRequired} + className="w-full" + variant="node" /> + + {/* Optional credentials toggle - only show in builder canvas, not run dialogs */} + {nodeId && + !formContext?.readOnly && + formContext?.showOptionalToggle !== false && ( +
+ + setCredentialsOptional(nodeId, checked) + } + /> + +
+ )}
); }; diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/CredentialField/components/CredentialFieldTitle.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/CredentialField/components/CredentialFieldTitle.tsx index ca14c8a4ce..347f4e089a 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/CredentialField/components/CredentialFieldTitle.tsx +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/CredentialField/components/CredentialFieldTitle.tsx @@ -18,8 +18,9 @@ export const CredentialFieldTitle = (props: { uiOptions: UiSchema; schema: RJSFSchema; fieldPathId: FieldPathId; + required?: boolean; }) => { - const { registry, uiOptions, schema, fieldPathId } = props; + const { registry, uiOptions, schema, fieldPathId, required = false } = props; const { nodeId } = registry.formContext; const TitleFieldTemplate = getTemplate( @@ -50,7 +51,7 @@ export const CredentialFieldTitle = (props: { { + const { + formData, + onChange, + schema, + registry, + uiSchema, + required, + name, + fieldPathId, + } = props; + + const uiOptions = getUiOptions(uiSchema); + + const TitleFieldTemplate = getTemplate( + "TitleFieldTemplate", + registry, + uiOptions, + ); + + const fieldId = fieldPathId?.$id ?? props.id ?? "json-field"; + + const handleId = getHandleId({ + uiOptions, + id: fieldId, + schema: schema, + }); + + const updatedUiSchema = updateUiOption(uiSchema, { + handleId: handleId, + }); + + const { + textValue, + isModalOpen, + handleChange, + handleModalOpen, + handleModalClose, + handleModalSave, + } = useJsonTextField({ + formData, + onChange, + path: fieldPathId?.path, + }); + + const placeholder = getPlaceholder(schema); + const title = schema.title || name || "JSON Value"; + + return ( +
+ +
+ + + + + + + Expand input + +
+ {schema.description && ( + {schema.description} + )} + + +
+ ); +}; + +export default JsonTextField; diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/JsonTextField/helpers.ts b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/JsonTextField/helpers.ts new file mode 100644 index 0000000000..fea0f20dbc --- /dev/null +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/JsonTextField/helpers.ts @@ -0,0 +1,67 @@ +import { RJSFSchema } from "@rjsf/utils"; + +/** + * Converts form data to a JSON string for display + * @param formData - The data to stringify + * @returns JSON string or empty string if data is null/undefined + */ +export function stringifyFormData(formData: unknown): string { + if (formData === undefined || formData === null) { + return ""; + } + try { + return JSON.stringify(formData, null, 2); + } catch { + return ""; + } +} + +/** + * Parses a JSON string into an object/array + * @param value - The JSON string to parse + * @returns Parsed value or undefined if parsing fails or empty + */ +export function parseJsonValue(value: string): unknown | undefined { + const trimmed = value.trim(); + if (trimmed === "") { + return undefined; + } + + try { + return JSON.parse(trimmed); + } catch { + return undefined; + } +} + +/** + * Gets the appropriate placeholder text based on schema type + * @param schema - The JSON schema + * @returns Placeholder string + */ +export function getPlaceholder(schema: RJSFSchema): string { + if (schema.type === "array") { + return '["item1", "item2"] or [{"key": "value"}]'; + } + if (schema.type === "object") { + return '{"key": "value"}'; + } + return "Enter JSON value..."; +} + +/** + * Checks if a JSON string is valid + * @param value - The JSON string to validate + * @returns true if valid JSON, false otherwise + */ +export function isValidJson(value: string): boolean { + if (value.trim() === "") { + return true; // Empty is considered valid (will be undefined) + } + try { + JSON.parse(value); + return true; + } catch { + return false; + } +} diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/JsonTextField/useJsonTextField.ts b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/JsonTextField/useJsonTextField.ts new file mode 100644 index 0000000000..85dc69cfd3 --- /dev/null +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/JsonTextField/useJsonTextField.ts @@ -0,0 +1,107 @@ +import { useState, useEffect, useCallback } from "react"; +import { FieldProps } from "@rjsf/utils"; +import { stringifyFormData, parseJsonValue, isValidJson } from "./helpers"; + +type FieldOnChange = FieldProps["onChange"]; +type FieldPathId = FieldProps["fieldPathId"]; + +interface UseJsonTextFieldOptions { + formData: unknown; + onChange: FieldOnChange; + path?: FieldPathId["path"]; +} + +interface UseJsonTextFieldReturn { + textValue: string; + isModalOpen: boolean; + hasError: boolean; + handleChange: ( + e: React.ChangeEvent, + ) => void; + handleModalOpen: () => void; + handleModalClose: () => void; + handleModalSave: (value: string) => void; +} + +/** + * Custom hook for managing JSON text field state and handlers + */ +export function useJsonTextField({ + formData, + onChange, + path, +}: UseJsonTextFieldOptions): UseJsonTextFieldReturn { + const [textValue, setTextValue] = useState(() => stringifyFormData(formData)); + const [isModalOpen, setIsModalOpen] = useState(false); + const [hasError, setHasError] = useState(false); + + // Update text value when formData changes externally + useEffect(() => { + const newValue = stringifyFormData(formData); + setTextValue(newValue); + setHasError(false); + }, [formData]); + + const handleChange = useCallback( + (e: React.ChangeEvent) => { + const value = e.target.value; + setTextValue(value); + + // Validate JSON and update error state + const valid = isValidJson(value); + setHasError(!valid); + + // Try to parse and update formData + if (value.trim() === "") { + onChange(undefined, path ?? []); + return; + } + + const parsed = parseJsonValue(value); + if (parsed !== undefined) { + onChange(parsed, path ?? []); + } + }, + [onChange, path], + ); + + const handleModalOpen = useCallback(() => { + setIsModalOpen(true); + }, []); + + const handleModalClose = useCallback(() => { + setIsModalOpen(false); + }, []); + + const handleModalSave = useCallback( + (value: string) => { + setTextValue(value); + setIsModalOpen(false); + + // Validate and update + const valid = isValidJson(value); + setHasError(!valid); + + if (value.trim() === "") { + onChange(undefined, path ?? []); + return; + } + + const parsed = parseJsonValue(value); + if (parsed !== undefined) { + onChange(parsed, path ?? []); + } + }, + [onChange, path], + ); + + return { + textValue, + isModalOpen, + hasError, + handleChange, + handleModalOpen, + handleModalClose, + handleModalSave, + }; +} diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/MultiSelectField/MultiSelectField.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/MultiSelectField/MultiSelectField.tsx new file mode 100644 index 0000000000..dcae2f3bed --- /dev/null +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/MultiSelectField/MultiSelectField.tsx @@ -0,0 +1,57 @@ +import React from "react"; +import { FieldProps, getUiOptions } from "@rjsf/utils"; +import { BlockIOObjectSubSchema } from "@/lib/autogpt-server-api/types"; +import { + MultiSelector, + MultiSelectorContent, + MultiSelectorInput, + MultiSelectorItem, + MultiSelectorList, + MultiSelectorTrigger, +} from "@/components/__legacy__/ui/multiselect"; +import { cn } from "@/lib/utils"; +import { useMultiSelectField } from "./useMultiSelectField"; + +export const MultiSelectField = (props: FieldProps) => { + const { schema, formData, onChange, fieldPathId } = props; + const uiOptions = getUiOptions(props.uiSchema); + + const { optionSchema, options, selection, createChangeHandler } = + useMultiSelectField({ + schema: schema as BlockIOObjectSubSchema, + formData, + }); + + const handleValuesChange = createChangeHandler(onChange, fieldPathId); + + const displayName = schema.title || "options"; + + return ( +
+ + + + + + + {options + .map((key) => ({ ...optionSchema[key], key })) + .map(({ key, title, description }) => ( + + {title ?? key} + + ))} + + + +
+ ); +}; diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/MultiSelectField/index.ts b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/MultiSelectField/index.ts new file mode 100644 index 0000000000..4d49ec7dbb --- /dev/null +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/MultiSelectField/index.ts @@ -0,0 +1 @@ +export { MultiSelectField } from "./MultiSelectField"; diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/MultiSelectField/useMultiSelectField.ts b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/MultiSelectField/useMultiSelectField.ts new file mode 100644 index 0000000000..c04173dcfe --- /dev/null +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/MultiSelectField/useMultiSelectField.ts @@ -0,0 +1,65 @@ +import { FieldProps } from "@rjsf/utils"; +import { BlockIOObjectSubSchema } from "@/lib/autogpt-server-api/types"; + +type FormData = Record | null | undefined; + +interface UseMultiSelectFieldOptions { + schema: BlockIOObjectSubSchema; + formData: FormData; +} + +export function useMultiSelectField({ + schema, + formData, +}: UseMultiSelectFieldOptions) { + const getOptionSchema = (): Record => { + if (schema.properties) { + return schema.properties as Record; + } + if ( + "anyOf" in schema && + Array.isArray(schema.anyOf) && + schema.anyOf.length > 0 && + "properties" in schema.anyOf[0] + ) { + return (schema.anyOf[0] as BlockIOObjectSubSchema).properties as Record< + string, + BlockIOObjectSubSchema + >; + } + return {}; + }; + + const optionSchema = getOptionSchema(); + const options = Object.keys(optionSchema); + + const getSelection = (): string[] => { + if (!formData || typeof formData !== "object") { + return []; + } + return Object.entries(formData) + .filter(([_, value]) => value === true) + .map(([key]) => key); + }; + + const selection = getSelection(); + + const createChangeHandler = + ( + onChange: FieldProps["onChange"], + fieldPathId: FieldProps["fieldPathId"], + ) => + (values: string[]) => { + const newValue = Object.fromEntries( + options.map((opt) => [opt, values.includes(opt)]), + ); + onChange(newValue, fieldPathId?.path); + }; + + return { + optionSchema, + options, + selection, + createChangeHandler, + }; +} diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/TableField/TableField.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/TableField/TableField.tsx new file mode 100644 index 0000000000..b48eca3238 --- /dev/null +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/TableField/TableField.tsx @@ -0,0 +1,52 @@ +import { descriptionId, FieldProps, getTemplate, titleId } from "@rjsf/utils"; +import { Table, RowData } from "@/components/molecules/Table/Table"; +import { useMemo } from "react"; + +export const TableField = (props: FieldProps) => { + const { schema, formData, onChange, fieldPathId, registry, uiSchema } = props; + + const itemSchema = schema.items as any; + const properties = itemSchema?.properties || {}; + + const columns: string[] = useMemo(() => { + return Object.keys(properties); + }, [properties]); + + const handleChange = (rows: RowData[]) => { + onChange(rows, fieldPathId?.path.slice(0, -1)); + }; + + const TitleFieldTemplate = getTemplate("TitleFieldTemplate", registry); + const DescriptionFieldTemplate = getTemplate( + "DescriptionFieldTemplate", + registry, + ); + + return ( +
+ + + + + + ); +}; diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/custom-registry.ts b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/custom-registry.ts index 91850e3f10..30d2c27a5a 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/custom-registry.ts +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/custom/custom-registry.ts @@ -1,6 +1,10 @@ import { FieldProps, RJSFSchema, RegistryFieldsType } from "@rjsf/utils"; import { CredentialsField } from "./CredentialField/CredentialField"; import { GoogleDrivePickerField } from "./GoogleDrivePickerField/GoogleDrivePickerField"; +import { JsonTextField } from "./JsonTextField/JsonTextField"; +import { MultiSelectField } from "./MultiSelectField/MultiSelectField"; +import { isMultiSelectSchema } from "../utils/schema-utils"; +import { TableField } from "./TableField/TableField"; export interface CustomFieldDefinition { id: string; @@ -8,6 +12,9 @@ export interface CustomFieldDefinition { component: (props: FieldProps) => JSX.Element | null; } +/** Field ID for JsonTextField - used to render nested complex types as text input */ +export const JSON_TEXT_FIELD_ID = "custom/json_text_field"; + export const CUSTOM_FIELDS: CustomFieldDefinition[] = [ { id: "custom/credential_field", @@ -30,6 +37,28 @@ export const CUSTOM_FIELDS: CustomFieldDefinition[] = [ }, component: GoogleDrivePickerField, }, + { + id: "custom/json_text_field", + // Not matched by schema - assigned via uiSchema for nested complex types + matcher: () => false, + component: JsonTextField, + }, + { + id: "custom/multi_select_field", + matcher: isMultiSelectSchema, + component: MultiSelectField, + }, + { + id: "custom/table_field", + matcher: (schema: any) => { + return ( + schema.type === "array" && + "format" in schema && + schema.format === "table" + ); + }, + component: TableField, + }, ]; export function findCustomFieldId(schema: any): string | null { diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/types.ts b/autogpt_platform/frontend/src/components/renderers/InputRenderer/types.ts index af2e8b7866..e667e27e2d 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/types.ts +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/types.ts @@ -6,6 +6,7 @@ export interface ExtendedFormContextType extends FormContextType { uiType?: BlockUIType; showHandles?: boolean; size?: "small" | "medium" | "large"; + showOptionalToggle?: boolean; } export type PathSegment = { diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/generate-ui-schema.ts b/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/generate-ui-schema.ts index 4a2f4fc44a..4012c39068 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/generate-ui-schema.ts +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/generate-ui-schema.ts @@ -1,19 +1,46 @@ import { RJSFSchema, UiSchema } from "@rjsf/utils"; -import { findCustomFieldId } from "../custom/custom-registry"; +import { + findCustomFieldId, + JSON_TEXT_FIELD_ID, +} from "../custom/custom-registry"; + +function isComplexType(schema: RJSFSchema): boolean { + return schema.type === "object" || schema.type === "array"; +} + +function hasComplexAnyOfOptions(schema: RJSFSchema): boolean { + const options = schema.anyOf || schema.oneOf; + if (!Array.isArray(options)) return false; + return options.some( + (opt: any) => + opt && + typeof opt === "object" && + (opt.type === "object" || opt.type === "array"), + ); +} /** * Generates uiSchema with ui:field settings for custom fields based on schema matchers. * This is the standard RJSF way to route fields to custom components. + * + * Nested complex types (arrays/objects inside arrays/objects) are rendered as JsonTextField + * to avoid deeply nested form UIs. Users can enter raw JSON for these fields. + * + * @param schema - The JSON schema + * @param existingUiSchema - Existing uiSchema to merge with + * @param insideComplexType - Whether we're already inside a complex type (object/array) */ export function generateUiSchemaForCustomFields( schema: RJSFSchema, existingUiSchema: UiSchema = {}, + insideComplexType: boolean = false, ): UiSchema { const uiSchema: UiSchema = { ...existingUiSchema }; if (schema.properties) { for (const [key, propSchema] of Object.entries(schema.properties)) { if (propSchema && typeof propSchema === "object") { + // First check for custom field matchers (credentials, google drive, etc.) const customFieldId = findCustomFieldId(propSchema); if (customFieldId) { @@ -21,8 +48,33 @@ export function generateUiSchemaForCustomFields( ...(uiSchema[key] as object), "ui:field": customFieldId, }; + // Skip further processing for custom fields + continue; } + // Handle nested complex types - render as JsonTextField + if (insideComplexType && isComplexType(propSchema as RJSFSchema)) { + uiSchema[key] = { + ...(uiSchema[key] as object), + "ui:field": JSON_TEXT_FIELD_ID, + }; + // Don't recurse further - this field is now a text input + continue; + } + + // Handle anyOf/oneOf inside complex types + if ( + insideComplexType && + hasComplexAnyOfOptions(propSchema as RJSFSchema) + ) { + uiSchema[key] = { + ...(uiSchema[key] as object), + "ui:field": JSON_TEXT_FIELD_ID, + }; + continue; + } + + // Recurse into object properties if ( propSchema.type === "object" && propSchema.properties && @@ -31,6 +83,7 @@ export function generateUiSchemaForCustomFields( const nestedUiSchema = generateUiSchemaForCustomFields( propSchema as RJSFSchema, (uiSchema[key] as UiSchema) || {}, + true, // Now inside a complex type ); uiSchema[key] = { ...(uiSchema[key] as object), @@ -38,9 +91,11 @@ export function generateUiSchemaForCustomFields( }; } + // Handle arrays if (propSchema.type === "array" && propSchema.items) { const itemsSchema = propSchema.items as RJSFSchema; if (itemsSchema && typeof itemsSchema === "object") { + // Check for custom field on array items const itemsCustomFieldId = findCustomFieldId(itemsSchema); if (itemsCustomFieldId) { uiSchema[key] = { @@ -49,10 +104,28 @@ export function generateUiSchemaForCustomFields( "ui:field": itemsCustomFieldId, }, }; + } else if (isComplexType(itemsSchema)) { + // Array items that are complex types become JsonTextField + uiSchema[key] = { + ...(uiSchema[key] as object), + items: { + "ui:field": JSON_TEXT_FIELD_ID, + }, + }; + } else if (hasComplexAnyOfOptions(itemsSchema)) { + // Array items with anyOf containing complex types become JsonTextField + uiSchema[key] = { + ...(uiSchema[key] as object), + items: { + "ui:field": JSON_TEXT_FIELD_ID, + }, + }; } else if (itemsSchema.properties) { + // Recurse into object items (but they're now inside a complex type) const itemsUiSchema = generateUiSchemaForCustomFields( itemsSchema, ((uiSchema[key] as UiSchema)?.items as UiSchema) || {}, + true, // Inside complex type (array) ); if (Object.keys(itemsUiSchema).length > 0) { uiSchema[key] = { @@ -63,6 +136,61 @@ export function generateUiSchemaForCustomFields( } } } + + // Handle anyOf/oneOf at root level - process complex options + if (!insideComplexType) { + const anyOfOptions = propSchema.anyOf || propSchema.oneOf; + + if (Array.isArray(anyOfOptions)) { + for (let i = 0; i < anyOfOptions.length; i++) { + const option = anyOfOptions[i] as RJSFSchema; + if (option && typeof option === "object") { + // Handle anyOf array options with complex items + if (option.type === "array" && option.items) { + const itemsSchema = option.items as RJSFSchema; + if (itemsSchema && typeof itemsSchema === "object") { + // Array items that are complex types become JsonTextField + if (isComplexType(itemsSchema)) { + uiSchema[key] = { + ...(uiSchema[key] as object), + items: { + "ui:field": JSON_TEXT_FIELD_ID, + }, + }; + } else if (hasComplexAnyOfOptions(itemsSchema)) { + uiSchema[key] = { + ...(uiSchema[key] as object), + items: { + "ui:field": JSON_TEXT_FIELD_ID, + }, + }; + } + } + } + + // Recurse into anyOf object options with properties + if ( + option.type === "object" && + option.properties && + typeof option.properties === "object" + ) { + const optionUiSchema = generateUiSchemaForCustomFields( + option, + {}, + true, // Inside complex type (anyOf object option) + ); + if (Object.keys(optionUiSchema).length > 0) { + // Store under the property key - RJSF will apply it + uiSchema[key] = { + ...(uiSchema[key] as object), + ...optionUiSchema, + }; + } + } + } + } + } + } } } } diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/schema-utils.ts b/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/schema-utils.ts index b1cfd37967..fecf2d77d1 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/schema-utils.ts +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/utils/schema-utils.ts @@ -1,7 +1,11 @@ import { getUiOptions, RJSFSchema, UiSchema } from "@rjsf/utils"; export function isAnyOfSchema(schema: RJSFSchema | undefined): boolean { - return Array.isArray(schema?.anyOf) && schema!.anyOf.length > 0; + return ( + Array.isArray(schema?.anyOf) && + schema!.anyOf.length > 0 && + schema?.enum === undefined + ); } export const isAnyOfChild = ( @@ -33,3 +37,21 @@ export function isOptionalType(schema: RJSFSchema | undefined): { export function isAnyOfSelector(name: string) { return name.includes("anyof_select"); } + +export function isMultiSelectSchema(schema: RJSFSchema | undefined): boolean { + if (typeof schema !== "object" || schema === null) { + return false; + } + + if ("anyOf" in schema || "oneOf" in schema) { + return false; + } + + return !!( + schema.type === "object" && + schema.properties && + Object.values(schema.properties).every( + (prop: any) => prop.type === "boolean", + ) + ); +} diff --git a/autogpt_platform/frontend/src/tests/agent-dashboard.spec.ts b/autogpt_platform/frontend/src/tests/agent-dashboard.spec.ts index 87dfff10ca..ea1629f929 100644 --- a/autogpt_platform/frontend/src/tests/agent-dashboard.spec.ts +++ b/autogpt_platform/frontend/src/tests/agent-dashboard.spec.ts @@ -83,7 +83,7 @@ test("agent table delete action works correctly", async ({ page }) => { const rows = agentTable.getByTestId("agent-table-row"); - // Delete button testing — delete the first agent in the list + // Delete button testing — only works for PENDING submissions const beforeCount = await rows.count(); if (beforeCount === 0) { @@ -91,11 +91,18 @@ test("agent table delete action works correctly", async ({ page }) => { return; } - const firstRow = rows.first(); - const deletedSubmissionId = await firstRow.getAttribute("data-submission-id"); - await firstRow.scrollIntoViewIfNeeded(); + // Find a PENDING submission to delete + const pendingRow = rows.filter({ hasText: "Pending" }).first(); + if (!(await pendingRow.count())) { + console.log("No pending agents available; skipping delete flow."); + return; + } - const delActionsButton = firstRow.getByTestId("agent-table-row-actions"); + const deletedSubmissionId = + await pendingRow.getAttribute("data-submission-id"); + await pendingRow.scrollIntoViewIfNeeded(); + + const delActionsButton = pendingRow.getByTestId("agent-table-row-actions"); await delActionsButton.waitFor({ state: "visible", timeout: 10000 }); await delActionsButton.scrollIntoViewIfNeeded(); await delActionsButton.click(); @@ -108,7 +115,7 @@ test("agent table delete action works correctly", async ({ page }) => { await isHidden(page.locator(`[data-submission-id="${deletedSubmissionId}"]`)); }); -test("edit action is unavailable for rejected agents (view only)", async ({ +test("edit and delete actions are unavailable for non-pending submissions", async ({ page, }) => { await page.goto("/profile/dashboard"); @@ -118,27 +125,39 @@ test("edit action is unavailable for rejected agents (view only)", async ({ const rows = agentTable.getByTestId("agent-table-row"); + // Test with rejected submissions (view only) const rejectedRow = rows.filter({ hasText: "Rejected" }).first(); - if (!(await rejectedRow.count())) { - console.log("No rejected agents available; skipping rejected edit test."); - return; + if (await rejectedRow.count()) { + await rejectedRow.scrollIntoViewIfNeeded(); + const actionsButton = rejectedRow.getByTestId("agent-table-row-actions"); + await actionsButton.waitFor({ state: "visible", timeout: 10000 }); + await actionsButton.scrollIntoViewIfNeeded(); + await actionsButton.click(); + + await expect(page.getByRole("menuitem", { name: "View" })).toBeVisible(); + await expect(page.getByRole("menuitem", { name: "Edit" })).toHaveCount(0); + await expect(page.getByRole("menuitem", { name: "Delete" })).toHaveCount(0); + + // Close the menu + await page.keyboard.press("Escape"); } - await rejectedRow.scrollIntoViewIfNeeded(); + // Test with approved submissions (view only) + const approvedRow = rows.filter({ hasText: "Approved" }).first(); + if (await approvedRow.count()) { + await approvedRow.scrollIntoViewIfNeeded(); + const actionsButton = approvedRow.getByTestId("agent-table-row-actions"); + await actionsButton.waitFor({ state: "visible", timeout: 10000 }); + await actionsButton.scrollIntoViewIfNeeded(); + await actionsButton.click(); - const actionsButton = rejectedRow.getByTestId("agent-table-row-actions"); - await actionsButton.waitFor({ state: "visible", timeout: 10000 }); - await actionsButton.scrollIntoViewIfNeeded(); - await actionsButton.click(); - - // Rejected should not show Edit, only View - await expect(page.getByRole("menuitem", { name: "View" })).toBeVisible(); - await expect(page.getByRole("menuitem", { name: "Edit" })).toHaveCount(0); + await expect(page.getByRole("menuitem", { name: "View" })).toBeVisible(); + await expect(page.getByRole("menuitem", { name: "Edit" })).toHaveCount(0); + await expect(page.getByRole("menuitem", { name: "Delete" })).toHaveCount(0); + } }); -test("editing an approved agent creates a new pending submission", async ({ - page, -}) => { +test("editing a pending submission works correctly", async ({ page }) => { await page.goto("/profile/dashboard"); const agentTable = page.getByTestId("agent-table"); @@ -146,16 +165,17 @@ test("editing an approved agent creates a new pending submission", async ({ const rows = agentTable.getByTestId("agent-table-row"); - const approvedRow = rows.filter({ hasText: "Approved" }).first(); - if (!(await approvedRow.count())) { - console.log("No approved agents available; skipping approved edit test."); + // Find a PENDING submission to edit (only PENDING submissions can be edited) + const pendingRow = rows.filter({ hasText: "Pending" }).first(); + if (!(await pendingRow.count())) { + console.log("No pending agents available; skipping edit test."); return; } const beforeCount = await rows.count(); - await approvedRow.scrollIntoViewIfNeeded(); - const actionsButton = approvedRow.getByTestId("agent-table-row-actions"); + await pendingRow.scrollIntoViewIfNeeded(); + const actionsButton = pendingRow.getByTestId("agent-table-row-actions"); await actionsButton.waitFor({ state: "visible", timeout: 10000 }); await actionsButton.scrollIntoViewIfNeeded(); await actionsButton.click(); @@ -167,11 +187,11 @@ test("editing an approved agent creates a new pending submission", async ({ const editModal = page.getByTestId("edit-agent-modal"); await expect(editModal).toBeVisible(); - const newTitle = `E2E Edit Approved ${Date.now()}`; + const newTitle = `E2E Edit Pending ${Date.now()}`; await page.getByRole("textbox", { name: "Title" }).fill(newTitle); await page .getByRole("textbox", { name: "Changes Summary" }) - .fill("E2E change - approved -> new pending submission"); + .fill("E2E change - updating pending submission"); await page.getByRole("button", { name: "Update submission" }).click(); await expect(editModal).not.toBeVisible(); diff --git a/docs/_javascript/mathjax.js b/docs/_javascript/mathjax.js deleted file mode 100644 index a80ddbff75..0000000000 --- a/docs/_javascript/mathjax.js +++ /dev/null @@ -1,16 +0,0 @@ -window.MathJax = { - tex: { - inlineMath: [["\\(", "\\)"]], - displayMath: [["\\[", "\\]"]], - processEscapes: true, - processEnvironments: true - }, - options: { - ignoreHtmlClass: ".*|", - processHtmlClass: "arithmatex" - } -}; - -document$.subscribe(() => { - MathJax.typesetPromise() -}) \ No newline at end of file diff --git a/docs/_javascript/tablesort.js b/docs/_javascript/tablesort.js deleted file mode 100644 index ee04e90082..0000000000 --- a/docs/_javascript/tablesort.js +++ /dev/null @@ -1,6 +0,0 @@ -document$.subscribe(function () { - var tables = document.querySelectorAll("article table:not([class])") - tables.forEach(function (table) { - new Tablesort(table) - }) -}) \ No newline at end of file diff --git a/docs/content/contribute/index.md b/docs/content/contribute/index.md index 4cce72aacf..3b458a04df 100644 --- a/docs/content/contribute/index.md +++ b/docs/content/contribute/index.md @@ -1,48 +1,35 @@ # Contributing to the Docs -We welcome contributions to our documentation! If you would like to contribute, please follow the steps below. +We welcome contributions to our documentation! Our docs are hosted on GitBook and synced with GitHub. -## Setting up the Docs +## How It Works -1. Clone the repository: +- Documentation lives in the `docs/` directory on the `gitbook` branch +- GitBook automatically syncs changes from GitHub +- You can edit docs directly on GitHub or locally + +## Editing Docs Locally + +1. Clone the repository and switch to the gitbook branch: ```shell - git clone github.com/Significant-Gravitas/AutoGPT.git + git clone https://github.com/Significant-Gravitas/AutoGPT.git + cd AutoGPT + git checkout gitbook ``` -1. Install the dependencies: +2. Make your changes to markdown files in `docs/` - ```shell - python -m pip install -r docs/requirements.txt - ``` +3. Preview changes: + - Push to a branch and create a PR - GitBook will generate a preview + - Or use any markdown preview tool locally - or +## Adding a New Page - ```shell - python3 -m pip install -r docs/requirements.txt - ``` - -1. Start iterating using mkdocs' live server: - - ```shell - mkdocs serve - ``` - -1. Open your browser and navigate to `http://127.0.0.1:8000`. - -1. The server will automatically reload the docs when you save your changes. - -## Adding a new page - -1. Create a new markdown file in the `docs/content` directory. -1. Add the new page to the `nav` section in the `mkdocs.yml` file. -1. Add the content to the new markdown file. -1. Run `mkdocs serve` to see your changes. - -## Checking links - -To check for broken links in the documentation, run `mkdocs build` and look for warnings in the console output. +1. Create a new markdown file in the appropriate `docs/` subdirectory +2. Add the new page to the relevant `SUMMARY.md` file to include it in the navigation +3. Submit a pull request to the `gitbook` branch ## Submitting a Pull Request -When you're ready to submit your changes, please create a pull request. We will review your changes and merge them if they are appropriate. +When you're ready to submit your changes, create a pull request targeting the `gitbook` branch. We will review your changes and merge them if appropriate. diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml deleted file mode 100644 index 876467633e..0000000000 --- a/docs/mkdocs.yml +++ /dev/null @@ -1,194 +0,0 @@ -site_name: AutoGPT Documentation -site_url: https://docs.agpt.co/ -repo_url: https://github.com/Significant-Gravitas/AutoGPT -repo_name: AutoGPT -edit_uri: edit/master/docs/content -docs_dir: content -nav: - - Home: index.md - - - The AutoGPT Platform 🆕: - - Getting Started: - - Setup AutoGPT (Local-Host): platform/getting-started.md - - Edit an Agent: platform/edit-agent.md - - Delete an Agent: platform/delete-agent.md - - Download & Import and Agent: platform/download-agent-from-marketplace-local.md - - Create a Basic Agent: platform/create-basic-agent.md - - Submit an Agent to the Marketplace: platform/submit-agent-to-marketplace.md - - Advanced Setup: platform/advanced_setup.md - - Agent Blocks: platform/agent-blocks.md - - Build your own Blocks: platform/new_blocks.md - - Block SDK Guide: platform/block-sdk-guide.md - - Using Ollama: platform/ollama.md - - Using AI/ML API: platform/aimlapi.md - - Using D-ID: platform/d_id.md - - Blocks: platform/blocks/blocks.md - - API: - - Introduction: platform/integrating/api-guide.md - - OAuth & SSO: platform/integrating/oauth-guide.md - - Contributing: - - Tests: platform/contributing/tests.md - - OAuth Flows: platform/contributing/oauth-integration-flow.md - - AutoGPT Classic: - - Introduction: classic/index.md - - Setup: - - Setting up AutoGPT: classic/setup/index.md - - Set up with Docker: classic/setup/docker.md - - For Developers: classic/setup/for-developers.md - - Configuration: - - Options: classic/configuration/options.md - - Search: classic/configuration/search.md - - Voice: classic/configuration/voice.md - - Usage: classic/usage.md - - Help us improve AutoGPT: - - Share your debug logs with us: classic/share-your-logs.md - - Contribution guide: contributing.md - - Running tests: classic/testing.md - - Code of Conduct: code-of-conduct.md - - Benchmark: - - Readme: https://github.com/Significant-Gravitas/AutoGPT/blob/master/classic/benchmark/README.md - - Forge: - - Introduction: forge/get-started.md - - Components: - - Introduction: forge/components/introduction.md - - Agents: forge/components/agents.md - - Components: forge/components/components.md - - Protocols: forge/components/protocols.md - - Commands: forge/components/commands.md - - Built in Components: forge/components/built-in-components.md - - Creating Components: forge/components/creating-components.md - - Frontend: - - Readme: https://github.com/Significant-Gravitas/AutoGPT/blob/master/classic/frontend/README.md - - - Contribute: - - Introduction: contribute/index.md - - Testing: ../../autogpt_platform/backend/TESTING.md - - # - Challenges: - # - Introduction: challenges/introduction.md - # - List of Challenges: - # - Memory: - # - Introduction: challenges/memory/introduction.md - # - Memory Challenge A: challenges/memory/challenge_a.md - # - Memory Challenge B: challenges/memory/challenge_b.md - # - Memory Challenge C: challenges/memory/challenge_c.md - # - Memory Challenge D: challenges/memory/challenge_d.md - # - Information retrieval: - # - Introduction: challenges/information_retrieval/introduction.md - # - Information Retrieval Challenge A: challenges/information_retrieval/challenge_a.md - # - Information Retrieval Challenge B: challenges/information_retrieval/challenge_b.md - # - Submit a Challenge: challenges/submit.md - # - Beat a Challenge: challenges/beat.md - - - License: https://github.com/Significant-Gravitas/AutoGPT/blob/master/LICENSE - -theme: - name: material - custom_dir: overrides - language: en - icon: - repo: fontawesome/brands/github - logo: material/book-open-variant - edit: material/pencil - view: material/eye - favicon: assets/favicon.png - features: - - navigation.sections - - navigation.footer - - navigation.top - - navigation.tracking - - navigation.tabs - # - navigation.path - - toc.follow - - toc.integrate - - content.action.edit - - content.action.view - - content.code.copy - - content.code.annotate - - content.tabs.link - palette: - # Palette toggle for light mode - - media: "(prefers-color-scheme: light)" - scheme: default - toggle: - icon: material/weather-night - name: Switch to dark mode - - # Palette toggle for dark mode - - media: "(prefers-color-scheme: dark)" - scheme: slate - toggle: - icon: material/weather-sunny - name: Switch to light mode - -markdown_extensions: - # Python Markdown - - abbr - - admonition - - attr_list - - def_list - - footnotes - - md_in_html - - toc: - permalink: true - - tables - - # Python Markdown Extensions - - pymdownx.arithmatex: - generic: true - - pymdownx.betterem: - smart_enable: all - - pymdownx.critic - - pymdownx.caret - - pymdownx.details - - pymdownx.emoji: - emoji_index: !!python/name:material.extensions.emoji.twemoji - emoji_generator: !!python/name:material.extensions.emoji.to_svg - - pymdownx.highlight - - pymdownx.inlinehilite - - pymdownx.keys - - pymdownx.mark - - pymdownx.smartsymbols - - pymdownx.snippets: - base_path: ['.','../'] - check_paths: true - dedent_subsections: true - - pymdownx.superfences: - custom_fences: - - name: mermaid - class: mermaid - format: !!python/name:pymdownx.superfences.fence_code_format - - pymdownx.tabbed: - alternate_style: true - - pymdownx.tasklist: - custom_checkbox: true - - pymdownx.tilde - -plugins: - - table-reader - - search - - git-revision-date-localized: - enable_creation_date: true - -extra: - social: - - icon: fontawesome/brands/github - link: https://github.com/Significant-Gravitas/AutoGPT - - icon: fontawesome/brands/x-twitter - link: https://x.com/Auto_GPT - - icon: fontawesome/brands/instagram - link: https://www.instagram.com/autogpt/ - - icon: fontawesome/brands/discord - link: https://discord.gg/autogpt - -extra: - analytics: - provider: google - property: G-XKPNKB9XG6 - -extra_javascript: - - https://unpkg.com/tablesort@5.3.0/dist/tablesort.min.js - - _javascript/tablesort.js - - _javascript/mathjax.js - - https://cdnjs.cloudflare.com/polyfill/v3/polyfill.min.js?features=es6 - - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js diff --git a/docs/netlify.toml b/docs/netlify.toml deleted file mode 100644 index c07b733a57..0000000000 --- a/docs/netlify.toml +++ /dev/null @@ -1,5 +0,0 @@ -# Netlify config for AutoGPT docs - -[build] - publish = "public/" - command = "mkdocs build -d public" diff --git a/docs/overrides/assets/favicon.png b/docs/overrides/assets/favicon.png deleted file mode 100644 index 69bded6abf..0000000000 Binary files a/docs/overrides/assets/favicon.png and /dev/null differ diff --git a/docs/overrides/main.html b/docs/overrides/main.html deleted file mode 100644 index b376724b8b..0000000000 --- a/docs/overrides/main.html +++ /dev/null @@ -1,61 +0,0 @@ -{% extends "base.html" %} - - -{% block extrahead %} - - - - - -{% raw %} - -{% endraw %} -{% endblock %} \ No newline at end of file diff --git a/docs/platform/SUMMARY.md b/docs/platform/SUMMARY.md new file mode 100644 index 0000000000..b71f6dc929 --- /dev/null +++ b/docs/platform/SUMMARY.md @@ -0,0 +1,34 @@ +# Table of contents + +* [What is the AutoGPT Platform?](what-is-autogpt-platform.md) + +## Getting Started + +* [Setting Up Auto-GPT (Local Host)](getting-started.md) +* [AutoGPT Platform Installer](installer.md) +* [Edit an Agent](edit-agent.md) +* [Delete an Agent](delete-agent.md) +* [Download & Import an Agent](download-agent-from-marketplace-local.md) +* [Create a Basic Agent](create-basic-agent.md) +* [Submit an Agent to the Marketplace](submit-agent-to-marketplace.md) + +## Advanced Setup + +* [Advanced Setup](advanced_setup.md) + +## Building Blocks + +* [Agent Blocks Overview](agent-blocks.md) +* [Build your own Blocks](new_blocks.md) +* [Block SDK Guide](block-sdk-guide.md) + +## Using AI Services + +* [Using Ollama](ollama.md) +* [Using AI/ML API](aimlapi.md) +* [Using D-ID](d_id.md) + +## API & Integrations + +* [API Introduction](integrating/api-guide.md) +* [OAuth & SSO](integrating/oauth-guide.md) diff --git a/docs/content/platform/advanced_setup.md b/docs/platform/advanced_setup.md similarity index 100% rename from docs/content/platform/advanced_setup.md rename to docs/platform/advanced_setup.md diff --git a/docs/content/platform/agent-blocks.md b/docs/platform/agent-blocks.md similarity index 100% rename from docs/content/platform/agent-blocks.md rename to docs/platform/agent-blocks.md diff --git a/docs/content/platform/aimlapi.md b/docs/platform/aimlapi.md similarity index 100% rename from docs/content/platform/aimlapi.md rename to docs/platform/aimlapi.md diff --git a/docs/content/platform/block-sdk-guide.md b/docs/platform/block-sdk-guide.md similarity index 96% rename from docs/content/platform/block-sdk-guide.md rename to docs/platform/block-sdk-guide.md index 1f8fc545e9..5b3eda5184 100644 --- a/docs/content/platform/block-sdk-guide.md +++ b/docs/platform/block-sdk-guide.md @@ -227,7 +227,7 @@ backend/blocks/my_provider/ ## Best Practices -1. **Error Handling**: Error output pin is already defined on BlockSchemaOutput +1. **Error Handling**: Use `BlockInputError` for validation failures and `BlockExecutionError` for runtime errors (import from `backend.util.exceptions`). These inherit from `ValueError` so the executor treats them as user-fixable. See [Error Handling in new_blocks.md](new_blocks.md#error-handling) for details. 2. **Credentials**: Use the provider's `credentials_field()` method 3. **Validation**: Use SchemaField constraints (ge, le, min_length, etc.) 4. **Categories**: Choose appropriate categories for discoverability diff --git a/docs/content/platform/blocks/ai_condition.md b/docs/platform/blocks/ai_condition.md similarity index 100% rename from docs/content/platform/blocks/ai_condition.md rename to docs/platform/blocks/ai_condition.md diff --git a/docs/content/platform/blocks/ai_shortform_video_block.md b/docs/platform/blocks/ai_shortform_video_block.md similarity index 100% rename from docs/content/platform/blocks/ai_shortform_video_block.md rename to docs/platform/blocks/ai_shortform_video_block.md diff --git a/docs/content/platform/blocks/basic.md b/docs/platform/blocks/basic.md similarity index 100% rename from docs/content/platform/blocks/basic.md rename to docs/platform/blocks/basic.md diff --git a/docs/content/platform/blocks/blocks.md b/docs/platform/blocks/blocks.md similarity index 100% rename from docs/content/platform/blocks/blocks.md rename to docs/platform/blocks/blocks.md diff --git a/docs/content/platform/blocks/branching.md b/docs/platform/blocks/branching.md similarity index 100% rename from docs/content/platform/blocks/branching.md rename to docs/platform/blocks/branching.md diff --git a/docs/content/platform/blocks/csv.md b/docs/platform/blocks/csv.md similarity index 100% rename from docs/content/platform/blocks/csv.md rename to docs/platform/blocks/csv.md diff --git a/docs/content/platform/blocks/decoder_block.md b/docs/platform/blocks/decoder_block.md similarity index 100% rename from docs/content/platform/blocks/decoder_block.md rename to docs/platform/blocks/decoder_block.md diff --git a/docs/content/platform/blocks/discord.md b/docs/platform/blocks/discord.md similarity index 100% rename from docs/content/platform/blocks/discord.md rename to docs/platform/blocks/discord.md diff --git a/docs/content/platform/blocks/email_block.md b/docs/platform/blocks/email_block.md similarity index 100% rename from docs/content/platform/blocks/email_block.md rename to docs/platform/blocks/email_block.md diff --git a/docs/content/platform/blocks/flux_kontext.md b/docs/platform/blocks/flux_kontext.md similarity index 100% rename from docs/content/platform/blocks/flux_kontext.md rename to docs/platform/blocks/flux_kontext.md diff --git a/docs/content/platform/blocks/github/issues.md b/docs/platform/blocks/github/issues.md similarity index 100% rename from docs/content/platform/blocks/github/issues.md rename to docs/platform/blocks/github/issues.md diff --git a/docs/content/platform/blocks/github/pull_requests.md b/docs/platform/blocks/github/pull_requests.md similarity index 100% rename from docs/content/platform/blocks/github/pull_requests.md rename to docs/platform/blocks/github/pull_requests.md diff --git a/docs/content/platform/blocks/github/repo.md b/docs/platform/blocks/github/repo.md similarity index 100% rename from docs/content/platform/blocks/github/repo.md rename to docs/platform/blocks/github/repo.md diff --git a/docs/content/platform/blocks/google/gmail.md b/docs/platform/blocks/google/gmail.md similarity index 100% rename from docs/content/platform/blocks/google/gmail.md rename to docs/platform/blocks/google/gmail.md diff --git a/docs/content/platform/blocks/google/sheet.md b/docs/platform/blocks/google/sheet.md similarity index 100% rename from docs/content/platform/blocks/google/sheet.md rename to docs/platform/blocks/google/sheet.md diff --git a/docs/content/platform/blocks/google_maps.md b/docs/platform/blocks/google_maps.md similarity index 100% rename from docs/content/platform/blocks/google_maps.md rename to docs/platform/blocks/google_maps.md diff --git a/docs/content/platform/blocks/http.md b/docs/platform/blocks/http.md similarity index 100% rename from docs/content/platform/blocks/http.md rename to docs/platform/blocks/http.md diff --git a/docs/content/platform/blocks/ideogram.md b/docs/platform/blocks/ideogram.md similarity index 100% rename from docs/content/platform/blocks/ideogram.md rename to docs/platform/blocks/ideogram.md diff --git a/docs/content/platform/blocks/iteration.md b/docs/platform/blocks/iteration.md similarity index 100% rename from docs/content/platform/blocks/iteration.md rename to docs/platform/blocks/iteration.md diff --git a/docs/content/platform/blocks/llm.md b/docs/platform/blocks/llm.md similarity index 100% rename from docs/content/platform/blocks/llm.md rename to docs/platform/blocks/llm.md diff --git a/docs/content/platform/blocks/maths.md b/docs/platform/blocks/maths.md similarity index 100% rename from docs/content/platform/blocks/maths.md rename to docs/platform/blocks/maths.md diff --git a/docs/content/platform/blocks/medium.md b/docs/platform/blocks/medium.md similarity index 100% rename from docs/content/platform/blocks/medium.md rename to docs/platform/blocks/medium.md diff --git a/docs/content/platform/blocks/reddit.md b/docs/platform/blocks/reddit.md similarity index 100% rename from docs/content/platform/blocks/reddit.md rename to docs/platform/blocks/reddit.md diff --git a/docs/content/platform/blocks/replicate_flux_advanced.md b/docs/platform/blocks/replicate_flux_advanced.md similarity index 100% rename from docs/content/platform/blocks/replicate_flux_advanced.md rename to docs/platform/blocks/replicate_flux_advanced.md diff --git a/docs/content/platform/blocks/rss.md b/docs/platform/blocks/rss.md similarity index 100% rename from docs/content/platform/blocks/rss.md rename to docs/platform/blocks/rss.md diff --git a/docs/content/platform/blocks/sampling.md b/docs/platform/blocks/sampling.md similarity index 100% rename from docs/content/platform/blocks/sampling.md rename to docs/platform/blocks/sampling.md diff --git a/docs/content/platform/blocks/search.md b/docs/platform/blocks/search.md similarity index 100% rename from docs/content/platform/blocks/search.md rename to docs/platform/blocks/search.md diff --git a/docs/content/platform/blocks/talking_head.md b/docs/platform/blocks/talking_head.md similarity index 100% rename from docs/content/platform/blocks/talking_head.md rename to docs/platform/blocks/talking_head.md diff --git a/docs/content/platform/blocks/text.md b/docs/platform/blocks/text.md similarity index 100% rename from docs/content/platform/blocks/text.md rename to docs/platform/blocks/text.md diff --git a/docs/content/platform/blocks/text_to_speech_block.md b/docs/platform/blocks/text_to_speech_block.md similarity index 100% rename from docs/content/platform/blocks/text_to_speech_block.md rename to docs/platform/blocks/text_to_speech_block.md diff --git a/docs/content/platform/blocks/time_blocks.md b/docs/platform/blocks/time_blocks.md similarity index 100% rename from docs/content/platform/blocks/time_blocks.md rename to docs/platform/blocks/time_blocks.md diff --git a/docs/content/platform/blocks/todoist.md b/docs/platform/blocks/todoist.md similarity index 100% rename from docs/content/platform/blocks/todoist.md rename to docs/platform/blocks/todoist.md diff --git a/docs/content/platform/blocks/twitter/twitter.md b/docs/platform/blocks/twitter/twitter.md similarity index 100% rename from docs/content/platform/blocks/twitter/twitter.md rename to docs/platform/blocks/twitter/twitter.md diff --git a/docs/content/platform/blocks/youtube.md b/docs/platform/blocks/youtube.md similarity index 100% rename from docs/content/platform/blocks/youtube.md rename to docs/platform/blocks/youtube.md diff --git a/docs/content/platform/contributing/oauth-integration-flow.md b/docs/platform/contributing/oauth-integration-flow.md similarity index 100% rename from docs/content/platform/contributing/oauth-integration-flow.md rename to docs/platform/contributing/oauth-integration-flow.md diff --git a/docs/content/platform/contributing/tests.md b/docs/platform/contributing/tests.md similarity index 100% rename from docs/content/platform/contributing/tests.md rename to docs/platform/contributing/tests.md diff --git a/docs/content/platform/create-basic-agent.md b/docs/platform/create-basic-agent.md similarity index 100% rename from docs/content/platform/create-basic-agent.md rename to docs/platform/create-basic-agent.md diff --git a/docs/content/platform/d_id.md b/docs/platform/d_id.md similarity index 100% rename from docs/content/platform/d_id.md rename to docs/platform/d_id.md diff --git a/docs/content/platform/delete-agent.md b/docs/platform/delete-agent.md similarity index 100% rename from docs/content/platform/delete-agent.md rename to docs/platform/delete-agent.md diff --git a/docs/content/platform/download-agent-from-marketplace-local.md b/docs/platform/download-agent-from-marketplace-local.md similarity index 100% rename from docs/content/platform/download-agent-from-marketplace-local.md rename to docs/platform/download-agent-from-marketplace-local.md diff --git a/docs/content/platform/edit-agent.md b/docs/platform/edit-agent.md similarity index 100% rename from docs/content/platform/edit-agent.md rename to docs/platform/edit-agent.md diff --git a/docs/content/platform/getting-started.md b/docs/platform/getting-started.md similarity index 100% rename from docs/content/platform/getting-started.md rename to docs/platform/getting-started.md diff --git a/docs/content/platform/installer.md b/docs/platform/installer.md similarity index 100% rename from docs/content/platform/installer.md rename to docs/platform/installer.md diff --git a/docs/content/platform/integrating/api-guide.md b/docs/platform/integrating/api-guide.md similarity index 100% rename from docs/content/platform/integrating/api-guide.md rename to docs/platform/integrating/api-guide.md diff --git a/docs/content/platform/integrating/oauth-guide.md b/docs/platform/integrating/oauth-guide.md similarity index 100% rename from docs/content/platform/integrating/oauth-guide.md rename to docs/platform/integrating/oauth-guide.md diff --git a/docs/content/platform/new_blocks.md b/docs/platform/new_blocks.md similarity index 93% rename from docs/content/platform/new_blocks.md rename to docs/platform/new_blocks.md index 8e6324b937..d9d329ff51 100644 --- a/docs/content/platform/new_blocks.md +++ b/docs/platform/new_blocks.md @@ -616,7 +616,59 @@ custom_requests = Requests( ### Error Handling -All blocks should have an error output that catches all reasonable errors that a user can handle, wrap them in a ValueError, and re-raise. Don't catch things the system admin would need to fix like being out of money or unreachable addresses. +Blocks should raise appropriate exceptions for errors that users can fix. The executor classifies errors based on whether they inherit from `ValueError` - these are treated as "expected failures" (user-fixable) rather than system errors. + +#### Block Exception Classes + +Import from `backend.util.exceptions`: + +```python +from backend.util.exceptions import BlockInputError, BlockExecutionError +``` + +| Exception | Use Case | Example | +|-----------|----------|---------| +| `BlockInputError` | Invalid user input, validation failures, missing required fields | Bad API key format, invalid URL, missing credentials | +| `BlockExecutionError` | Runtime failures the user can address | API errors, auth failures, resource not found, rate limits | +| `ValueError` | Simple cases (auto-wrapped to `BlockExecutionError`) | Basic validation errors | + +#### Raising Exceptions + +```python +from backend.util.exceptions import BlockInputError, BlockExecutionError + +class MyBlock(Block): + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + # Input validation - use BlockInputError + if not input_data.api_key: + raise BlockInputError( + message="API key is required", + block_name=self.name, + block_id=self.id, + ) + + try: + result = await self.call_api(input_data) + yield "result", result + except AuthenticationError as e: + # API/runtime errors - use BlockExecutionError + raise BlockExecutionError( + message=f"Authentication failed: {e}", + block_name=self.name, + block_id=self.id, + ) from e +``` + +#### What NOT to Catch + +Don't catch errors that require system admin intervention: + +- Out of money/credits +- Unreachable infrastructure +- Database connection failures +- Internal server errors from your own services + +Let these propagate as unexpected errors so they get proper attention. ### Data Models diff --git a/docs/content/platform/ollama.md b/docs/platform/ollama.md similarity index 100% rename from docs/content/platform/ollama.md rename to docs/platform/ollama.md diff --git a/docs/content/platform/submit-agent-to-marketplace.md b/docs/platform/submit-agent-to-marketplace.md similarity index 100% rename from docs/content/platform/submit-agent-to-marketplace.md rename to docs/platform/submit-agent-to-marketplace.md diff --git a/docs/platform/what-is-autogpt-platform.md b/docs/platform/what-is-autogpt-platform.md new file mode 100644 index 0000000000..b0d1ff0f74 --- /dev/null +++ b/docs/platform/what-is-autogpt-platform.md @@ -0,0 +1,82 @@ +# What is the AutoGPT Platform? + +The AutoGPT Platform is a groundbreaking system that revolutionizes AI utilization for businesses and individuals. It enables the creation, deployment, and management of continuous agents that work tirelessly on your behalf, bringing unprecedented efficiency and innovation to your workflows. + +## Key Features + +* **Seamless Integration and Low-Code Workflows**: Rapidly create complex workflows without extensive coding knowledge. +* **Autonomous Operation and Continuous Agents**: Deploy cloud-based assistants that run indefinitely, activating on relevant triggers. +* **Intelligent Automation and Maximum Efficiency**: Streamline workflows by automating repetitive processes. +* **Reliable Performance and Predictable Execution**: Enjoy consistent and dependable long-running processes. + +## Platform Architecture + +The AutoGPT Platform consists of two main components: + +### 1. AutoGPT Server + +The powerhouse of our platform, containing: + +* **Source Code**: Core logic driving agents and automation processes. +* **Infrastructure**: Robust systems ensuring reliable and scalable performance. +* **Marketplace**: A comprehensive marketplace for pre-built agents. + +### 2. AutoGPT Frontend + +The user interface where you interact with the platform: + +* **Agent Builder**: Design and configure your own AI agents. +* **Workflow Management**: Build, modify, and optimize automation workflows. +* **Deployment Controls**: Manage the lifecycle of your agents. +* **Ready-to-Use Agents**: Select from pre-configured agents. +* **Agent Interaction**: Run and interact with agents through a user-friendly interface. +* **Monitoring and Analytics**: Track agent performance and gain insights. + +## Platform Components + +### Agents and Workflows + +In the platform, you can create highly customized workflows to build agents. An agent is essentially an automated workflow that you design to perform specific tasks or processes. Create customized workflows to build agents for various tasks, including: + +* Data processing and analysis +* Task scheduling and management +* Communication and notification systems +* Integration between different software tools +* AI-powered decision making and content generation + +### Blocks as Integrations + +Blocks represent actions and are the building blocks of your workflows, including: + +* Connections to external services +* Data processing tools +* AI models for various tasks +* Custom scripts or functions +* Conditional logic and decision-making components + +You can learn more under: [Build your own Blocks](new_blocks.md) + +## Available Language Models + +The platform comes pre-integrated with cutting-edge LLM providers: + +* OpenAI - +* Anthropic - +* Groq - +* Llama - +* AI/ML API - + * AI/ML API provides 300+ AI models including Deepseek, Gemini, ChatGPT. The models run at enterprise-grade rate limits and uptimes. + +## License Overview + +We've adopted a dual-license approach to balance open collaboration with sustainable development: + +* **MIT License**: The majority of the AutoGPT repository remains under this license. +* **Polyform Shield License**: Applies to the new `autogpt_platform` folder. + +This strategy allows us to share previously closed-source components, fostering a vibrant ecosystem of developers and users. + +## Ready to Get Started? + +* Read the [Getting Started docs](getting-started.md) to self-host. +* [Join the waitlist](https://agpt.co/waitlist) for the cloud-hosted beta. diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index ea0eab8c2a..0000000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -mkdocs -mkdocs-material -mkdocs-table-reader-plugin -pymdown-extensions -mkdocs-git-revision-date-localized-plugin -zipp>=3.19.1 # not directly required, pinned by Snyk to avoid a vulnerability -urllib3>=2.2.2 # not directly required, pinned by Snyk to avoid a vulnerability -requests>=2.32.4 # not directly required, pinned by Snyk to avoid a vulnerability