diff --git a/.branchlet.json b/.branchlet.json new file mode 100644 index 0000000000..cc13ff9f74 --- /dev/null +++ b/.branchlet.json @@ -0,0 +1,37 @@ +{ + "worktreeCopyPatterns": [ + ".env*", + ".vscode/**", + ".auth/**", + ".claude/**", + "autogpt_platform/.env*", + "autogpt_platform/backend/.env*", + "autogpt_platform/frontend/.env*", + "autogpt_platform/frontend/.auth/**", + "autogpt_platform/db/docker/.env*" + ], + "worktreeCopyIgnores": [ + "**/node_modules/**", + "**/dist/**", + "**/.git/**", + "**/Thumbs.db", + "**/.DS_Store", + "**/.next/**", + "**/__pycache__/**", + "**/.ruff_cache/**", + "**/.pytest_cache/**", + "**/*.pyc", + "**/playwright-report/**", + "**/logs/**", + "**/site/**" + ], + "worktreePathTemplate": "$BASE_PATH.worktree", + "postCreateCmd": [ + "cd autogpt_platform/autogpt_libs && poetry install", + "cd autogpt_platform/backend && poetry install && poetry run prisma generate", + "cd autogpt_platform/frontend && pnpm install", + "cd docs && pip install -r requirements.txt" + ], + "terminalCommand": "code .", + "deleteBranchWithWorktree": false +} diff --git a/.dockerignore b/.dockerignore index 94bf1742f1..c9524ce700 100644 --- a/.dockerignore +++ b/.dockerignore @@ -16,6 +16,7 @@ !autogpt_platform/backend/poetry.lock !autogpt_platform/backend/README.md !autogpt_platform/backend/.env +!autogpt_platform/backend/gen_prisma_types_stub.py # Platform - Market !autogpt_platform/market/market/ diff --git a/.github/workflows/claude-dependabot.yml b/.github/workflows/claude-dependabot.yml index 20b6f1d28e..1fd0da3d8e 100644 --- a/.github/workflows/claude-dependabot.yml +++ b/.github/workflows/claude-dependabot.yml @@ -74,7 +74,7 @@ jobs: - name: Generate Prisma Client working-directory: autogpt_platform/backend - run: poetry run prisma generate + run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) - name: Set up Node.js diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 3f5e8c22ec..71c6ef49c2 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -90,7 +90,7 @@ jobs: - name: Generate Prisma Client working-directory: autogpt_platform/backend - run: poetry run prisma generate + run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) - name: Set up Node.js diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index 13ef01cc44..aac8befee0 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -72,7 +72,7 @@ jobs: - name: Generate Prisma Client working-directory: autogpt_platform/backend - run: poetry run prisma generate + run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) - name: Set up Node.js @@ -108,6 +108,16 @@ jobs: # run: pnpm playwright install --with-deps chromium # Docker setup for development environment + - name: Free up disk space + run: | + # Remove large unused tools to free disk space for Docker builds + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo docker system prune -af + df -h + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 diff --git a/.github/workflows/platform-backend-ci.yml b/.github/workflows/platform-backend-ci.yml index f962382fa5..da5ab83c1c 100644 --- 
a/.github/workflows/platform-backend-ci.yml +++ b/.github/workflows/platform-backend-ci.yml @@ -134,7 +134,7 @@ jobs: run: poetry install - name: Generate Prisma Client - run: poetry run prisma generate + run: poetry run prisma generate && poetry run gen-prisma-stub - id: supabase name: Start Supabase diff --git a/autogpt_platform/Makefile b/autogpt_platform/Makefile index d99fee49d7..2ff454e392 100644 --- a/autogpt_platform/Makefile +++ b/autogpt_platform/Makefile @@ -12,6 +12,7 @@ reset-db: rm -rf db/docker/volumes/db/data cd backend && poetry run prisma migrate deploy cd backend && poetry run prisma generate + cd backend && poetry run gen-prisma-stub # View logs for core services logs-core: @@ -33,6 +34,7 @@ init-env: migrate: cd backend && poetry run prisma migrate deploy cd backend && poetry run prisma generate + cd backend && poetry run gen-prisma-stub run-backend: cd backend && poetry run app diff --git a/autogpt_platform/backend/Dockerfile b/autogpt_platform/backend/Dockerfile index 7f51bad3a1..b3389d1787 100644 --- a/autogpt_platform/backend/Dockerfile +++ b/autogpt_platform/backend/Dockerfile @@ -48,7 +48,8 @@ RUN poetry install --no-ansi --no-root # Generate Prisma client COPY autogpt_platform/backend/schema.prisma ./ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py -RUN poetry run prisma generate +COPY autogpt_platform/backend/gen_prisma_types_stub.py ./ +RUN poetry run prisma generate && poetry run gen-prisma-stub FROM debian:13-slim AS server_dependencies diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index 69ed0d2730..1c17e7b36c 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -489,7 +489,7 @@ async def update_agent_version_in_library( agent_graph_version: int, ) -> library_model.LibraryAgent: """ - Updates the agent version in the library if useGraphIsActiveVersion is True. + Updates the agent version in the library for any agent owned by the user. Args: user_id: Owner of the LibraryAgent. @@ -498,20 +498,31 @@ async def update_agent_version_in_library( Raises: DatabaseError: If there's an error with the update. + NotFoundError: If no library agent is found for this user and agent. 
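    Note:
        The update below first removes any conflicting row inside the same
        transaction. A minimal sketch of the conflict this guards against,
        assuming LibraryAgent is unique over (userId, agentGraphId,
        agentGraphVersion) and using made-up IDs:

            # User u1 has two LibraryAgent rows for graph g1:
            #   row A at agentGraphVersion=1 (the row being updated)
            #   row B at agentGraphVersion=3 (a stale row for the target version)
            # Moving A to v3 while B exists would collide, so B is deleted
            # first, in the same transaction as the update:
            await prisma.models.LibraryAgent.prisma(tx).delete_many(
                where={
                    "userId": "u1",
                    "agentGraphId": "g1",
                    "agentGraphVersion": 3,
                    "id": {"not": "A"},
                }
            )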
""" logger.debug( f"Updating agent version in library for user #{user_id}, " f"agent #{agent_graph_id} v{agent_graph_version}" ) - try: - library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise( + async with transaction() as tx: + library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise( where={ "userId": user_id, "agentGraphId": agent_graph_id, - "useGraphIsActiveVersion": True, }, ) - lib = await prisma.models.LibraryAgent.prisma().update( + + # Delete any conflicting LibraryAgent for the target version + await prisma.models.LibraryAgent.prisma(tx).delete_many( + where={ + "userId": user_id, + "agentGraphId": agent_graph_id, + "agentGraphVersion": agent_graph_version, + "id": {"not": library_agent.id}, + } + ) + + lib = await prisma.models.LibraryAgent.prisma(tx).update( where={"id": library_agent.id}, data={ "AgentGraph": { @@ -525,13 +536,13 @@ async def update_agent_version_in_library( }, include={"AgentGraph": True}, ) - if lib is None: - raise NotFoundError(f"Library agent {library_agent.id} not found") - return library_model.LibraryAgent.from_db(lib) - except prisma.errors.PrismaError as e: - logger.error(f"Database error updating agent version in library: {e}") - raise DatabaseError("Failed to update agent version in library") from e + if lib is None: + raise NotFoundError( + f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}" + ) + + return library_model.LibraryAgent.from_db(lib) async def update_library_agent( @@ -825,6 +836,7 @@ async def add_store_agent_to_library( } }, "isCreatedByUser": False, + "useGraphIsActiveVersion": False, "settings": SafeJson( _initialize_graph_settings(graph_model).model_dump() ), diff --git a/autogpt_platform/backend/backend/api/features/library/model.py b/autogpt_platform/backend/backend/api/features/library/model.py index c20f82afae..56fad7bfd3 100644 --- a/autogpt_platform/backend/backend/api/features/library/model.py +++ b/autogpt_platform/backend/backend/api/features/library/model.py @@ -48,6 +48,7 @@ class LibraryAgent(pydantic.BaseModel): id: str graph_id: str graph_version: int + owner_user_id: str # ID of user who owns/created this agent graph image_url: str | None @@ -163,6 +164,7 @@ class LibraryAgent(pydantic.BaseModel): id=agent.id, graph_id=agent.agentGraphId, graph_version=agent.agentGraphVersion, + owner_user_id=agent.userId, image_url=agent.imageUrl, creator_name=creator_name, creator_image_url=creator_image_url, diff --git a/autogpt_platform/backend/backend/api/features/library/routes_test.py b/autogpt_platform/backend/backend/api/features/library/routes_test.py index ad28b5b6bd..0f05240a7f 100644 --- a/autogpt_platform/backend/backend/api/features/library/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/library/routes_test.py @@ -42,6 +42,7 @@ async def test_get_library_agents_success( id="test-agent-1", graph_id="test-agent-1", graph_version=1, + owner_user_id=test_user_id, name="Test Agent 1", description="Test Description 1", image_url=None, @@ -64,6 +65,7 @@ async def test_get_library_agents_success( id="test-agent-2", graph_id="test-agent-2", graph_version=1, + owner_user_id=test_user_id, name="Test Agent 2", description="Test Description 2", image_url=None, @@ -138,6 +140,7 @@ async def test_get_favorite_library_agents_success( id="test-agent-1", graph_id="test-agent-1", graph_version=1, + owner_user_id=test_user_id, name="Favorite Agent 1", description="Test Favorite Description 1", image_url=None, @@ -205,6 +208,7 @@ def 
test_add_agent_to_library_success( id="test-library-agent-id", graph_id="test-agent-1", graph_version=1, + owner_user_id=test_user_id, name="Test Agent 1", description="Test Description 1", image_url=None, diff --git a/autogpt_platform/backend/backend/api/features/store/db.py b/autogpt_platform/backend/backend/api/features/store/db.py index 8e5a39df89..8e4310ee02 100644 --- a/autogpt_platform/backend/backend/api/features/store/db.py +++ b/autogpt_platform/backend/backend/api/features/store/db.py @@ -614,6 +614,7 @@ async def get_store_submissions( submission_models = [] for sub in submissions: submission_model = store_model.StoreSubmission( + listing_id=sub.listing_id, agent_id=sub.agent_id, agent_version=sub.agent_version, name=sub.name, @@ -667,35 +668,48 @@ async def delete_store_submission( submission_id: str, ) -> bool: """ - Delete a store listing submission as the submitting user. + Delete a store submission version as the submitting user. Args: user_id: ID of the authenticated user - submission_id: ID of the submission to be deleted + submission_id: StoreListingVersion ID to delete Returns: - bool: True if the submission was successfully deleted, False otherwise + bool: True if successfully deleted """ - logger.debug(f"Deleting store submission {submission_id} for user {user_id}") - try: - # Verify the submission belongs to this user - submission = await prisma.models.StoreListing.prisma().find_first( - where={"agentGraphId": submission_id, "owningUserId": user_id} + # Find the submission version with ownership check + version = await prisma.models.StoreListingVersion.prisma().find_first( + where={"id": submission_id}, include={"StoreListing": True} ) - if not submission: - logger.warning(f"Submission not found for user {user_id}: {submission_id}") - raise store_exceptions.SubmissionNotFoundError( - f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}" + if ( + not version + or not version.StoreListing + or version.StoreListing.owningUserId != user_id + ): + raise store_exceptions.SubmissionNotFoundError("Submission not found") + + # Prevent deletion of approved submissions + if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED: + raise store_exceptions.InvalidOperationError( + "Cannot delete approved submissions" ) - # Delete the submission - await prisma.models.StoreListing.prisma().delete(where={"id": submission.id}) - - logger.debug( - f"Successfully deleted submission {submission_id} for user {user_id}" + # Delete the version + await prisma.models.StoreListingVersion.prisma().delete( + where={"id": version.id} ) + + # Clean up empty listing if this was the last version + remaining = await prisma.models.StoreListingVersion.prisma().count( + where={"storeListingId": version.storeListingId} + ) + if remaining == 0: + await prisma.models.StoreListing.prisma().delete( + where={"id": version.storeListingId} + ) + return True except Exception as e: @@ -759,9 +773,15 @@ async def create_store_submission( logger.warning( f"Agent not found for user {user_id}: {agent_id} v{agent_version}" ) - raise store_exceptions.AgentNotFoundError( - f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}" - ) + # Provide more user-friendly error message when agent_id is empty + if not agent_id or agent_id.strip() == "": + raise store_exceptions.AgentNotFoundError( + "No agent selected. Please select an agent before submitting to the store." 
+ ) + else: + raise store_exceptions.AgentNotFoundError( + f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}" + ) # Check if listing already exists for this agent existing_listing = await prisma.models.StoreListing.prisma().find_first( @@ -833,6 +853,7 @@ async def create_store_submission( logger.debug(f"Created store listing for agent {agent_id}") # Return submission details return store_model.StoreSubmission( + listing_id=listing.id, agent_id=agent_id, agent_version=agent_version, name=name, @@ -944,81 +965,56 @@ async def edit_store_submission( # Currently we are not allowing user to update the agent associated with a submission # If we allow it in future, then we need a check here to verify the agent belongs to this user. - # Check if we can edit this submission - if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED: + # Only allow editing of PENDING submissions + if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING: raise store_exceptions.InvalidOperationError( - "Cannot edit a rejected submission" - ) - - # For APPROVED submissions, we need to create a new version - if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED: - # Create a new version for the existing listing - return await create_store_version( - user_id=user_id, - agent_id=current_version.agentGraphId, - agent_version=current_version.agentGraphVersion, - store_listing_id=current_version.storeListingId, - name=name, - video_url=video_url, - agent_output_demo_url=agent_output_demo_url, - image_urls=image_urls, - description=description, - sub_heading=sub_heading, - categories=categories, - changes_summary=changes_summary, - recommended_schedule_cron=recommended_schedule_cron, - instructions=instructions, + f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited." 
) # For PENDING submissions, we can update the existing version - elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING: - # Update the existing version - updated_version = await prisma.models.StoreListingVersion.prisma().update( - where={"id": store_listing_version_id}, - data=prisma.types.StoreListingVersionUpdateInput( - name=name, - videoUrl=video_url, - agentOutputDemoUrl=agent_output_demo_url, - imageUrls=image_urls, - description=description, - categories=categories, - subHeading=sub_heading, - changesSummary=changes_summary, - recommendedScheduleCron=recommended_schedule_cron, - instructions=instructions, - ), - ) - - logger.debug( - f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}" - ) - - if not updated_version: - raise DatabaseError("Failed to update store listing version") - return store_model.StoreSubmission( - agent_id=current_version.agentGraphId, - agent_version=current_version.agentGraphVersion, + # Update the existing version + updated_version = await prisma.models.StoreListingVersion.prisma().update( + where={"id": store_listing_version_id}, + data=prisma.types.StoreListingVersionUpdateInput( name=name, - sub_heading=sub_heading, - slug=current_version.StoreListing.slug, + videoUrl=video_url, + agentOutputDemoUrl=agent_output_demo_url, + imageUrls=image_urls, description=description, - instructions=instructions, - image_urls=image_urls, - date_submitted=updated_version.submittedAt or updated_version.createdAt, - status=updated_version.submissionStatus, - runs=0, - rating=0.0, - store_listing_version_id=updated_version.id, - changes_summary=changes_summary, - video_url=video_url, categories=categories, - version=updated_version.version, - ) + subHeading=sub_heading, + changesSummary=changes_summary, + recommendedScheduleCron=recommended_schedule_cron, + instructions=instructions, + ), + ) - else: - raise store_exceptions.InvalidOperationError( - f"Cannot edit submission with status: {current_version.submissionStatus}" - ) + logger.debug( + f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}" + ) + + if not updated_version: + raise DatabaseError("Failed to update store listing version") + return store_model.StoreSubmission( + listing_id=current_version.StoreListing.id, + agent_id=current_version.agentGraphId, + agent_version=current_version.agentGraphVersion, + name=name, + sub_heading=sub_heading, + slug=current_version.StoreListing.slug, + description=description, + instructions=instructions, + image_urls=image_urls, + date_submitted=updated_version.submittedAt or updated_version.createdAt, + status=updated_version.submissionStatus, + runs=0, + rating=0.0, + store_listing_version_id=updated_version.id, + changes_summary=changes_summary, + video_url=video_url, + categories=categories, + version=updated_version.version, + ) except ( store_exceptions.SubmissionNotFoundError, @@ -1097,38 +1093,78 @@ async def create_store_version( f"Agent not found for this user. 
User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}" ) - # Get the latest version number - latest_version = listing.Versions[0] if listing.Versions else None - - next_version = (latest_version.version + 1) if latest_version else 1 - - # Create a new version for the existing listing - new_version = await prisma.models.StoreListingVersion.prisma().create( - data=prisma.types.StoreListingVersionCreateInput( - version=next_version, - agentGraphId=agent_id, - agentGraphVersion=agent_version, - name=name, - videoUrl=video_url, - agentOutputDemoUrl=agent_output_demo_url, - imageUrls=image_urls, - description=description, - instructions=instructions, - categories=categories, - subHeading=sub_heading, - submissionStatus=prisma.enums.SubmissionStatus.PENDING, - submittedAt=datetime.now(), - changesSummary=changes_summary, - recommendedScheduleCron=recommended_schedule_cron, - storeListingId=store_listing_id, + # Check if there's already a PENDING submission for this agent (any version) + existing_pending_submission = ( + await prisma.models.StoreListingVersion.prisma().find_first( + where=prisma.types.StoreListingVersionWhereInput( + storeListingId=store_listing_id, + agentGraphId=agent_id, + submissionStatus=prisma.enums.SubmissionStatus.PENDING, + isDeleted=False, + ) ) ) + # Handle existing pending submission and create new one atomically + async with transaction() as tx: + # Get the latest version number first + latest_listing = await prisma.models.StoreListing.prisma(tx).find_first( + where=prisma.types.StoreListingWhereInput( + id=store_listing_id, owningUserId=user_id + ), + include={"Versions": {"order_by": {"version": "desc"}, "take": 1}}, + ) + + if not latest_listing: + raise store_exceptions.ListingNotFoundError( + f"Store listing not found. 
User ID: {user_id}, Listing ID: {store_listing_id}" + ) + + latest_version = ( + latest_listing.Versions[0] if latest_listing.Versions else None + ) + next_version = (latest_version.version + 1) if latest_version else 1 + + # If there's an existing pending submission, delete it atomically before creating new one + if existing_pending_submission: + logger.info( + f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate" + ) + await prisma.models.StoreListingVersion.prisma(tx).delete( + where={"id": existing_pending_submission.id} + ) + logger.debug( + f"Deleted existing pending submission {existing_pending_submission.id}" + ) + + # Create a new version for the existing listing + new_version = await prisma.models.StoreListingVersion.prisma(tx).create( + data=prisma.types.StoreListingVersionCreateInput( + version=next_version, + agentGraphId=agent_id, + agentGraphVersion=agent_version, + name=name, + videoUrl=video_url, + agentOutputDemoUrl=agent_output_demo_url, + imageUrls=image_urls, + description=description, + instructions=instructions, + categories=categories, + subHeading=sub_heading, + submissionStatus=prisma.enums.SubmissionStatus.PENDING, + submittedAt=datetime.now(), + changesSummary=changes_summary, + recommendedScheduleCron=recommended_schedule_cron, + storeListingId=store_listing_id, + ) + ) + logger.debug( f"Created new version for listing {store_listing_id} of agent {agent_id}" ) # Return submission details return store_model.StoreSubmission( + listing_id=listing.id, agent_id=agent_id, agent_version=agent_version, name=name, @@ -1708,15 +1744,12 @@ async def review_store_submission( # Convert to Pydantic model for consistency return store_model.StoreSubmission( + listing_id=(submission.StoreListing.id if submission.StoreListing else ""), agent_id=submission.agentGraphId, agent_version=submission.agentGraphVersion, name=submission.name, sub_heading=submission.subHeading, - slug=( - submission.StoreListing.slug - if hasattr(submission, "storeListing") and submission.StoreListing - else "" - ), + slug=(submission.StoreListing.slug if submission.StoreListing else ""), description=submission.description, instructions=submission.instructions, image_urls=submission.imageUrls or [], @@ -1818,9 +1851,7 @@ async def get_admin_listings_with_versions( where = prisma.types.StoreListingWhereInput(**where_dict) include = prisma.types.StoreListingInclude( Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing( - order_by=prisma.types._StoreListingVersion_version_OrderByInput( - version="desc" - ) + order_by={"version": "desc"} ), OwningUser=True, ) @@ -1845,6 +1876,7 @@ async def get_admin_listings_with_versions( # If we have versions, turn them into StoreSubmission models for version in listing.Versions or []: version_model = store_model.StoreSubmission( + listing_id=listing.id, agent_id=version.agentGraphId, agent_version=version.agentGraphVersion, name=version.name, diff --git a/autogpt_platform/backend/backend/api/features/store/model.py b/autogpt_platform/backend/backend/api/features/store/model.py index 972898b296..077135217a 100644 --- a/autogpt_platform/backend/backend/api/features/store/model.py +++ b/autogpt_platform/backend/backend/api/features/store/model.py @@ -110,6 +110,7 @@ class Profile(pydantic.BaseModel): class StoreSubmission(pydantic.BaseModel): + listing_id: str agent_id: str agent_version: int name: str @@ -164,8 
+165,12 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel): class StoreSubmissionRequest(pydantic.BaseModel): - agent_id: str - agent_version: int + agent_id: str = pydantic.Field( + ..., min_length=1, description="Agent ID cannot be empty" + ) + agent_version: int = pydantic.Field( + ..., gt=0, description="Agent version must be greater than 0" + ) slug: str name: str sub_heading: str diff --git a/autogpt_platform/backend/backend/api/features/store/model_test.py b/autogpt_platform/backend/backend/api/features/store/model_test.py index a37966601b..fd09a0cf77 100644 --- a/autogpt_platform/backend/backend/api/features/store/model_test.py +++ b/autogpt_platform/backend/backend/api/features/store/model_test.py @@ -138,6 +138,7 @@ def test_creator_details(): def test_store_submission(): submission = store_model.StoreSubmission( + listing_id="listing123", agent_id="agent123", agent_version=1, sub_heading="Test subheading", @@ -159,6 +160,7 @@ def test_store_submissions_response(): response = store_model.StoreSubmissionsResponse( submissions=[ store_model.StoreSubmission( + listing_id="listing123", agent_id="agent123", agent_version=1, sub_heading="Test subheading", diff --git a/autogpt_platform/backend/backend/api/features/store/routes_test.py b/autogpt_platform/backend/backend/api/features/store/routes_test.py index 7fdc0b9ebb..36431c20ec 100644 --- a/autogpt_platform/backend/backend/api/features/store/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/store/routes_test.py @@ -521,6 +521,7 @@ def test_get_submissions_success( mocked_value = store_model.StoreSubmissionsResponse( submissions=[ store_model.StoreSubmission( + listing_id="test-listing-id", name="Test Agent", description="Test agent description", image_urls=["test.jpg"], diff --git a/autogpt_platform/backend/backend/blocks/airtable/_webhook.py b/autogpt_platform/backend/backend/blocks/airtable/_webhook.py index 58e6f95d0c..452630953e 100644 --- a/autogpt_platform/backend/backend/blocks/airtable/_webhook.py +++ b/autogpt_platform/backend/backend/blocks/airtable/_webhook.py @@ -6,6 +6,9 @@ import hashlib import hmac import logging from enum import Enum +from typing import cast + +from prisma.types import Serializable from backend.sdk import ( BaseWebhooksManager, @@ -84,7 +87,9 @@ class AirtableWebhookManager(BaseWebhooksManager): # update webhook config await update_webhook( webhook.id, - config={"base_id": base_id, "cursor": response.cursor}, + config=cast( + dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor} + ), ) event_type = "notification" diff --git a/autogpt_platform/backend/backend/blocks/helpers/review.py b/autogpt_platform/backend/backend/blocks/helpers/review.py new file mode 100644 index 0000000000..f35397e6aa --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/helpers/review.py @@ -0,0 +1,184 @@ +""" +Shared helpers for Human-In-The-Loop (HITL) review functionality. +Used by both the dedicated HumanInTheLoopBlock and blocks that require human review. 
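Typical flow (an illustrative sketch; the payload variable and block name are
hypothetical, while the keyword arguments mirror handle_review_decision below):

    decision = await HITLReviewHelper.handle_review_decision(
        input_data=payload,
        user_id=user_id,
        node_exec_id=node_exec_id,
        graph_exec_id=graph_exec_id,
        graph_id=graph_id,
        graph_version=graph_version,
        execution_context=execution_context,
        block_name="MyBlock",
    )
    if decision is None:
        return  # node parked in REVIEW status, awaiting a human decision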
+""" + +import logging +from typing import Any, Optional + +from prisma.enums import ReviewStatus +from pydantic import BaseModel + +from backend.data.execution import ExecutionContext, ExecutionStatus +from backend.data.human_review import ReviewResult +from backend.executor.manager import async_update_node_execution_status +from backend.util.clients import get_database_manager_async_client + +logger = logging.getLogger(__name__) + + +class ReviewDecision(BaseModel): + """Result of a review decision.""" + + should_proceed: bool + message: str + review_result: ReviewResult + + +class HITLReviewHelper: + """Helper class for Human-In-The-Loop review operations.""" + + @staticmethod + async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]: + """Create or retrieve a human review from the database.""" + return await get_database_manager_async_client().get_or_create_human_review( + **kwargs + ) + + @staticmethod + async def update_node_execution_status(**kwargs) -> None: + """Update the execution status of a node.""" + await async_update_node_execution_status( + db_client=get_database_manager_async_client(), **kwargs + ) + + @staticmethod + async def update_review_processed_status( + node_exec_id: str, processed: bool + ) -> None: + """Update the processed status of a review.""" + return await get_database_manager_async_client().update_review_processed_status( + node_exec_id, processed + ) + + @staticmethod + async def _handle_review_request( + input_data: Any, + user_id: str, + node_exec_id: str, + graph_exec_id: str, + graph_id: str, + graph_version: int, + execution_context: ExecutionContext, + block_name: str = "Block", + editable: bool = False, + ) -> Optional[ReviewResult]: + """ + Handle a review request for a block that requires human review. 
+ + Args: + input_data: The input data to be reviewed + user_id: ID of the user requesting the review + node_exec_id: ID of the node execution + graph_exec_id: ID of the graph execution + graph_id: ID of the graph + graph_version: Version of the graph + execution_context: Current execution context + block_name: Name of the block requesting review + editable: Whether the reviewer can edit the data + + Returns: + ReviewResult if review is complete, None if waiting for human input + + Raises: + Exception: If review creation or status update fails + """ + # Skip review if safe mode is disabled - return auto-approved result + if not execution_context.safe_mode: + logger.info( + f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled" + ) + return ReviewResult( + data=input_data, + status=ReviewStatus.APPROVED, + message="Auto-approved (safe mode disabled)", + processed=True, + node_exec_id=node_exec_id, + ) + + result = await HITLReviewHelper.get_or_create_human_review( + user_id=user_id, + node_exec_id=node_exec_id, + graph_exec_id=graph_exec_id, + graph_id=graph_id, + graph_version=graph_version, + input_data=input_data, + message=f"Review required for {block_name} execution", + editable=editable, + ) + + if result is None: + logger.info( + f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review" + ) + await HITLReviewHelper.update_node_execution_status( + exec_id=node_exec_id, + status=ExecutionStatus.REVIEW, + ) + return None # Signal that execution should pause + + # Mark review as processed if not already done + if not result.processed: + await HITLReviewHelper.update_review_processed_status( + node_exec_id=node_exec_id, processed=True + ) + + return result + + @staticmethod + async def handle_review_decision( + input_data: Any, + user_id: str, + node_exec_id: str, + graph_exec_id: str, + graph_id: str, + graph_version: int, + execution_context: ExecutionContext, + block_name: str = "Block", + editable: bool = False, + ) -> Optional[ReviewDecision]: + """ + Handle a review request and return the decision in a single call. 
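        Example outcome handling in a caller (illustrative; it mirrors how
        HumanInTheLoopBlock.run consumes the decision elsewhere in this patch):

            decision = await HITLReviewHelper.handle_review_decision(...)
            if decision is None:
                return  # pause: the node is awaiting human review
            if decision.should_proceed:
                yield "approved_data", decision.review_result.data
            else:
                yield "rejected_data", decision.review_result.data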
+ + Args: + input_data: The input data to be reviewed + user_id: ID of the user requesting the review + node_exec_id: ID of the node execution + graph_exec_id: ID of the graph execution + graph_id: ID of the graph + graph_version: Version of the graph + execution_context: Current execution context + block_name: Name of the block requesting review + editable: Whether the reviewer can edit the data + + Returns: + ReviewDecision if review is complete (approved/rejected), + None if execution should pause (awaiting review) + """ + review_result = await HITLReviewHelper._handle_review_request( + input_data=input_data, + user_id=user_id, + node_exec_id=node_exec_id, + graph_exec_id=graph_exec_id, + graph_id=graph_id, + graph_version=graph_version, + execution_context=execution_context, + block_name=block_name, + editable=editable, + ) + + if review_result is None: + # Still awaiting review - return None to pause execution + return None + + # Review is complete, determine outcome + should_proceed = review_result.status == ReviewStatus.APPROVED + message = review_result.message or ( + "Execution approved by reviewer" + if should_proceed + else "Execution rejected by reviewer" + ) + + return ReviewDecision( + should_proceed=should_proceed, message=message, review_result=review_result + ) diff --git a/autogpt_platform/backend/backend/blocks/human_in_the_loop.py b/autogpt_platform/backend/backend/blocks/human_in_the_loop.py index 13c9fb31db..1e338816c8 100644 --- a/autogpt_platform/backend/backend/blocks/human_in_the_loop.py +++ b/autogpt_platform/backend/backend/blocks/human_in_the_loop.py @@ -3,6 +3,7 @@ from typing import Any from prisma.enums import ReviewStatus +from backend.blocks.helpers.review import HITLReviewHelper from backend.data.block import ( Block, BlockCategory, @@ -11,11 +12,9 @@ from backend.data.block import ( BlockSchemaOutput, BlockType, ) -from backend.data.execution import ExecutionContext, ExecutionStatus +from backend.data.execution import ExecutionContext from backend.data.human_review import ReviewResult from backend.data.model import SchemaField -from backend.executor.manager import async_update_node_execution_status -from backend.util.clients import get_database_manager_async_client logger = logging.getLogger(__name__) @@ -72,32 +71,26 @@ class HumanInTheLoopBlock(Block): ("approved_data", {"name": "John Doe", "age": 30}), ], test_mock={ - "get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult( - data={"name": "John Doe", "age": 30}, - status=ReviewStatus.APPROVED, - message="", - processed=False, - node_exec_id="test-node-exec-id", - ), - "update_node_execution_status": lambda *_args, **_kwargs: None, - "update_review_processed_status": lambda *_args, **_kwargs: None, + "handle_review_decision": lambda **kwargs: type( + "ReviewDecision", + (), + { + "should_proceed": True, + "message": "Test approval message", + "review_result": ReviewResult( + data={"name": "John Doe", "age": 30}, + status=ReviewStatus.APPROVED, + message="", + processed=False, + node_exec_id="test-node-exec-id", + ), + }, + )(), }, ) - async def get_or_create_human_review(self, **kwargs): - return await get_database_manager_async_client().get_or_create_human_review( - **kwargs - ) - - async def update_node_execution_status(self, **kwargs): - return await async_update_node_execution_status( - db_client=get_database_manager_async_client(), **kwargs - ) - - async def update_review_processed_status(self, node_exec_id: str, processed: bool): - return await 
get_database_manager_async_client().update_review_processed_status( - node_exec_id, processed - ) + async def handle_review_decision(self, **kwargs): + return await HITLReviewHelper.handle_review_decision(**kwargs) async def run( self, @@ -109,7 +102,7 @@ class HumanInTheLoopBlock(Block): graph_id: str, graph_version: int, execution_context: ExecutionContext, - **kwargs, + **_kwargs, ) -> BlockOutput: if not execution_context.safe_mode: logger.info( @@ -119,48 +112,28 @@ class HumanInTheLoopBlock(Block): yield "review_message", "Auto-approved (safe mode disabled)" return - try: - result = await self.get_or_create_human_review( - user_id=user_id, - node_exec_id=node_exec_id, - graph_exec_id=graph_exec_id, - graph_id=graph_id, - graph_version=graph_version, - input_data=input_data.data, - message=input_data.name, - editable=input_data.editable, - ) - except Exception as e: - logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}") - raise + decision = await self.handle_review_decision( + input_data=input_data.data, + user_id=user_id, + node_exec_id=node_exec_id, + graph_exec_id=graph_exec_id, + graph_id=graph_id, + graph_version=graph_version, + execution_context=execution_context, + block_name=self.name, + editable=input_data.editable, + ) - if result is None: - logger.info( - f"HITL block pausing execution for node {node_exec_id} - awaiting human review" - ) - try: - await self.update_node_execution_status( - exec_id=node_exec_id, - status=ExecutionStatus.REVIEW, - ) - return - except Exception as e: - logger.error( - f"Failed to update node status for HITL block {node_exec_id}: {str(e)}" - ) - raise + if decision is None: + return - if not result.processed: - await self.update_review_processed_status( - node_exec_id=node_exec_id, processed=True - ) + status = decision.review_result.status + if status == ReviewStatus.APPROVED: + yield "approved_data", decision.review_result.data + elif status == ReviewStatus.REJECTED: + yield "rejected_data", decision.review_result.data + else: + raise RuntimeError(f"Unexpected review status: {status}") - if result.status == ReviewStatus.APPROVED: - yield "approved_data", result.data - if result.message: - yield "review_message", result.message - - elif result.status == ReviewStatus.REJECTED: - yield "rejected_data", result.data - if result.message: - yield "review_message", result.message + if decision.message: + yield "review_message", decision.message diff --git a/autogpt_platform/backend/backend/blocks/reddit.py b/autogpt_platform/backend/backend/blocks/reddit.py index 231e7affef..1109d568db 100644 --- a/autogpt_platform/backend/backend/blocks/reddit.py +++ b/autogpt_platform/backend/backend/blocks/reddit.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from typing import Iterator, Literal import praw +from praw.models import Comment, MoreComments, Submission from pydantic import BaseModel, SecretStr from backend.data.block import ( @@ -15,33 +16,51 @@ from backend.data.block import ( from backend.data.model import ( CredentialsField, CredentialsMetaInput, + OAuth2Credentials, SchemaField, - UserPasswordCredentials, ) from backend.integrations.providers import ProviderName from backend.util.mock import MockObject from backend.util.settings import Settings -RedditCredentials = UserPasswordCredentials +# Type aliases for Reddit API options +UserPostSort = Literal["new", "hot", "top", "controversial"] +SearchSort = Literal["relevance", "hot", "top", "new", "comments"] +TimeFilter = Literal["all", "day", "hour", "month", "week", 
"year"] +CommentSort = Literal["best", "top", "new", "controversial", "old", "qa"] +InboxType = Literal["all", "unread", "messages", "mentions", "comment_replies"] + +RedditCredentials = OAuth2Credentials RedditCredentialsInput = CredentialsMetaInput[ Literal[ProviderName.REDDIT], - Literal["user_password"], + Literal["oauth2"], ] def RedditCredentialsField() -> RedditCredentialsInput: """Creates a Reddit credentials input on a block.""" return CredentialsField( - description="The Reddit integration requires a username and password.", + description="Connect your Reddit account to access Reddit features.", ) -TEST_CREDENTIALS = UserPasswordCredentials( +TEST_CREDENTIALS = OAuth2Credentials( id="01234567-89ab-cdef-0123-456789abcdef", provider="reddit", - username=SecretStr("mock-reddit-username"), - password=SecretStr("mock-reddit-password"), + access_token=SecretStr("mock-reddit-access-token"), + refresh_token=SecretStr("mock-reddit-refresh-token"), + access_token_expires_at=9999999999, + scopes=[ + "identity", + "read", + "submit", + "edit", + "history", + "privatemessages", + "flair", + ], title="Mock Reddit credentials", + username="mock-reddit-username", ) TEST_CREDENTIALS_INPUT = { @@ -53,27 +72,29 @@ TEST_CREDENTIALS_INPUT = { class RedditPost(BaseModel): - id: str + post_id: str subreddit: str title: str body: str -class RedditComment(BaseModel): - post_id: str - comment: str - - settings = Settings() logger = logging.getLogger(__name__) def get_praw(creds: RedditCredentials) -> praw.Reddit: + """ + Create a PRAW Reddit client using OAuth2 credentials. + + Uses the refresh_token for authentication, which allows the client + to automatically refresh the access token when needed. + """ client = praw.Reddit( client_id=settings.secrets.reddit_client_id, client_secret=settings.secrets.reddit_client_secret, - username=creds.username.get_secret_value(), - password=creds.password.get_secret_value(), + refresh_token=( + creds.refresh_token.get_secret_value() if creds.refresh_token else None + ), user_agent=settings.config.reddit_user_agent, ) me = client.user.me() @@ -83,11 +104,36 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit: return client +def strip_reddit_prefix(id_str: str) -> str: + """ + Strip Reddit type prefix (t1_, t3_, etc.) from an ID if present. + + Reddit uses type prefixes like t1_ for comments, t3_ for posts, etc. + This helper normalizes IDs by removing these prefixes when present, + allowing blocks to accept both 'abc123' and 't3_abc123' formats. + + Args: + id_str: The ID string that may have a Reddit type prefix. + + Returns: + The ID without the type prefix. 
+ """ + if ( + len(id_str) > 3 + and id_str[0] == "t" + and id_str[1].isdigit() + and id_str[2] == "_" + ): + return id_str[3:] + return id_str + + class GetRedditPostsBlock(Block): class Input(BlockSchemaInput): subreddit: str = SchemaField( description="Subreddit name, excluding the /r/ prefix", default="writingprompts", + advanced=False, ) credentials: RedditCredentialsInput = RedditCredentialsField() last_minutes: int | None = SchemaField( @@ -128,26 +174,32 @@ class GetRedditPostsBlock(Block): ( "post", RedditPost( - id="id1", subreddit="subreddit", title="title1", body="body1" + post_id="id1", + subreddit="subreddit", + title="title1", + body="body1", ), ), ( "post", RedditPost( - id="id2", subreddit="subreddit", title="title2", body="body2" + post_id="id2", + subreddit="subreddit", + title="title2", + body="body2", ), ), ( "posts", [ RedditPost( - id="id1", + post_id="id1", subreddit="subreddit", title="title1", body="body1", ), RedditPost( - id="id2", + post_id="id2", subreddit="subreddit", title="title2", body="body2", @@ -184,13 +236,14 @@ class GetRedditPostsBlock(Block): ) time_difference = current_time - post_datetime if time_difference.total_seconds() / 60 > input_data.last_minutes: - continue + # Posts are ordered newest-first, so all subsequent posts will also be older + break if input_data.last_post and post.id == input_data.last_post: break reddit_post = RedditPost( - id=post.id, + post_id=post.id, subreddit=input_data.subreddit, title=post.title, body=post.selftext, @@ -204,10 +257,18 @@ class GetRedditPostsBlock(Block): class PostRedditCommentBlock(Block): class Input(BlockSchemaInput): credentials: RedditCredentialsInput = RedditCredentialsField() - data: RedditComment = SchemaField(description="Reddit comment") + post_id: str = SchemaField( + description="The ID of the post to comment on", + ) + comment: str = SchemaField( + description="The content of the comment to post", + ) class Output(BlockSchemaOutput): comment_id: str = SchemaField(description="Posted comment ID") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) def __init__(self): super().__init__( @@ -223,17 +284,24 @@ class PostRedditCommentBlock(Block): test_credentials=TEST_CREDENTIALS, test_input={ "credentials": TEST_CREDENTIALS_INPUT, - "data": {"post_id": "id", "comment": "comment"}, + "post_id": "test_post_id", + "comment": "comment", + }, + test_output=[ + ("comment_id", "dummy_comment_id"), + ("post_id", "test_post_id"), + ], + test_mock={ + "reply_post": lambda creds, post_id, comment: "dummy_comment_id" }, - test_output=[("comment_id", "dummy_comment_id")], - test_mock={"reply_post": lambda creds, comment: "dummy_comment_id"}, ) @staticmethod - def reply_post(creds: RedditCredentials, comment: RedditComment) -> str: + def reply_post(creds: RedditCredentials, post_id: str, comment: str) -> str: client = get_praw(creds) - submission = client.submission(id=comment.post_id) - new_comment = submission.reply(comment.comment) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + new_comment = submission.reply(comment) if not new_comment: raise ValueError("Failed to post comment.") return new_comment.id @@ -241,4 +309,2230 @@ class PostRedditCommentBlock(Block): async def run( self, input_data: Input, *, credentials: RedditCredentials, **kwargs ) -> BlockOutput: - yield "comment_id", self.reply_post(credentials, input_data.data) + yield "comment_id", self.reply_post( + credentials, + post_id=input_data.post_id, + 
comment=input_data.comment, + ) + yield "post_id", input_data.post_id + + +class CreateRedditPostBlock(Block): + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit to post to, excluding the /r/ prefix", + ) + title: str = SchemaField( + description="Title of the post", + ) + content: str = SchemaField( + description="Body text of the post (for text posts)", + default="", + ) + url: str | None = SchemaField( + description="URL to submit (for link posts). If provided, content is ignored.", + default=None, + ) + flair_id: str | None = SchemaField( + description="Flair template ID to apply to the post (from GetSubredditFlairsBlock)", + default=None, + ) + flair_text: str | None = SchemaField( + description="Custom flair text (only used if the flair template allows editing)", + default=None, + ) + + class Output(BlockSchemaOutput): + post_id: str = SchemaField(description="ID of the created post") + post_url: str = SchemaField(description="URL of the created post") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + + def __init__(self): + super().__init__( + id="f3a2b1c0-8d7e-4f6a-9b5c-1234567890ab", + description="Create a new post on a subreddit. Can create text posts or link posts.", + categories={BlockCategory.SOCIAL}, + input_schema=CreateRedditPostBlock.Input, + output_schema=CreateRedditPostBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "test", + "title": "Test Post", + "content": "This is a test post body.", + }, + test_output=[ + ("post_id", "abc123"), + ("post_url", "https://reddit.com/r/test/comments/abc123/test_post/"), + ("subreddit", "test"), + ], + test_mock={ + "create_post": lambda creds, subreddit, title, content, url, flair_id, flair_text: ( + "abc123", + "https://reddit.com/r/test/comments/abc123/test_post/", + ) + }, + ) + + @staticmethod + def create_post( + creds: RedditCredentials, + subreddit: str, + title: str, + content: str = "", + url: str | None = None, + flair_id: str | None = None, + flair_text: str | None = None, + ) -> tuple[str, str]: + """ + Create a new post on a subreddit. 
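        Illustrative chaining (hypothetical credentials and IDs): the post_id
        and subreddit pass-through outputs exist so this block can feed
        PostRedditCommentBlock directly, e.g.:

            post_id, post_url = CreateRedditPostBlock.create_post(
                creds, subreddit="test", title="Hello", content="First post"
            )
            comment_id = PostRedditCommentBlock.reply_post(
                creds, post_id=post_id, comment="Following up"
            )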
+ + Args: + creds: Reddit OAuth2 credentials + subreddit: Subreddit name (without /r/ prefix) + title: Post title + content: Post body text (for text posts) + url: URL to submit (for link posts, overrides content) + flair_id: Optional flair template ID to apply + flair_text: Optional custom flair text (for editable flairs) + + Returns: + Tuple of (post_id, post_url) + """ + client = get_praw(creds) + sub = client.subreddit(subreddit) + + if url: + submission = sub.submit( + title=title, url=url, flair_id=flair_id, flair_text=flair_text + ) + else: + submission = sub.submit( + title=title, selftext=content, flair_id=flair_id, flair_text=flair_text + ) + + return submission.id, f"https://reddit.com{submission.permalink}" + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + post_id, post_url = self.create_post( + credentials, + input_data.subreddit, + input_data.title, + input_data.content, + input_data.url, + input_data.flair_id, + input_data.flair_text, + ) + yield "post_id", post_id + yield "post_url", post_url + yield "subreddit", input_data.subreddit + + +class RedditPostDetails(BaseModel): + """Detailed information about a Reddit post.""" + + id: str + subreddit: str + title: str + body: str + author: str + score: int + upvote_ratio: float + num_comments: int + created_utc: float + url: str + permalink: str + is_self: bool + over_18: bool + + +class GetRedditPostBlock(Block): + """Get detailed information about a specific Reddit post.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_id: str = SchemaField( + description="The ID of the post to fetch (e.g., 'abc123' or full ID 't3_abc123')", + ) + + class Output(BlockSchemaOutput): + post: RedditPostDetails = SchemaField(description="Detailed post information") + error: str = SchemaField( + description="Error message if the post couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="36e6a259-168c-4032-83ec-b2935d0e4584", + description="Get detailed information about a specific Reddit post by its ID.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditPostBlock.Input, + output_schema=GetRedditPostBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_id": "abc123", + }, + test_output=[ + ( + "post", + RedditPostDetails( + id="abc123", + subreddit="test", + title="Test Post", + body="Test body", + author="testuser", + score=100, + upvote_ratio=0.95, + num_comments=10, + created_utc=1234567890.0, + url="https://reddit.com/r/test/comments/abc123/test_post/", + permalink="/r/test/comments/abc123/test_post/", + is_self=True, + over_18=False, + ), + ), + ], + test_mock={ + "get_post": lambda creds, post_id: RedditPostDetails( + id="abc123", + subreddit="test", + title="Test Post", + body="Test body", + author="testuser", + score=100, + upvote_ratio=0.95, + num_comments=10, + created_utc=1234567890.0, + url="https://reddit.com/r/test/comments/abc123/test_post/", + permalink="/r/test/comments/abc123/test_post/", + is_self=True, + over_18=False, + ) + }, + ) + + @staticmethod + def get_post(creds: RedditCredentials, post_id: str) -> RedditPostDetails: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + + return RedditPostDetails( + id=submission.id, + 
subreddit=submission.subreddit.display_name, + title=submission.title, + body=submission.selftext, + author=str(submission.author) if submission.author else "[deleted]", + score=submission.score, + upvote_ratio=submission.upvote_ratio, + num_comments=submission.num_comments, + created_utc=submission.created_utc, + url=submission.url, + permalink=submission.permalink, + is_self=submission.is_self, + over_18=submission.over_18, + ) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + post = self.get_post(credentials, input_data.post_id) + yield "post", post + except Exception as e: + yield "error", str(e) + + +class GetUserPostsBlock(Block): + """Get posts by a specific Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + username: str = SchemaField( + description="Reddit username to fetch posts from (without /u/ prefix)", + ) + post_limit: int = SchemaField( + description="Maximum number of posts to fetch", + default=10, + ) + sort: UserPostSort = SchemaField( + description="Sort order for user posts", + default="new", + ) + + class Output(BlockSchemaOutput): + post: RedditPost = SchemaField(description="A post by the user") + posts: list[RedditPost] = SchemaField(description="All posts by the user") + error: str = SchemaField( + description="Error message if posts couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="6fbe6329-d13e-4d2e-bd4d-b4d921b56161", + description="Fetch posts by a specific Reddit user.", + categories={BlockCategory.SOCIAL}, + input_schema=GetUserPostsBlock.Input, + output_schema=GetUserPostsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "username": "testuser", + "post_limit": 2, + }, + test_output=[ + ( + "post", + RedditPost( + post_id="id1", subreddit="sub1", title="title1", body="body1" + ), + ), + ( + "post", + RedditPost( + post_id="id2", subreddit="sub2", title="title2", body="body2" + ), + ), + ( + "posts", + [ + RedditPost( + post_id="id1", + subreddit="sub1", + title="title1", + body="body1", + ), + RedditPost( + post_id="id2", + subreddit="sub2", + title="title2", + body="body2", + ), + ], + ), + ], + test_mock={ + "get_user_posts": lambda creds, username, limit, sort: [ + MockObject( + id="id1", + subreddit=MockObject(display_name="sub1"), + title="title1", + selftext="body1", + ), + MockObject( + id="id2", + subreddit=MockObject(display_name="sub2"), + title="title2", + selftext="body2", + ), + ] + }, + ) + + @staticmethod + def get_user_posts( + creds: RedditCredentials, username: str, limit: int, sort: UserPostSort + ) -> list[Submission]: + client = get_praw(creds) + redditor = client.redditor(username) + + if sort == "new": + submissions = redditor.submissions.new(limit=limit) + elif sort == "hot": + submissions = redditor.submissions.hot(limit=limit) + elif sort == "top": + submissions = redditor.submissions.top(limit=limit) + elif sort == "controversial": + submissions = redditor.submissions.controversial(limit=limit) + else: + submissions = redditor.submissions.new(limit=limit) + + return list(submissions) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + submissions = self.get_user_posts( + credentials, + input_data.username, + input_data.post_limit, + 
input_data.sort, + ) + all_posts = [] + for submission in submissions: + post = RedditPost( + post_id=submission.id, + subreddit=submission.subreddit.display_name, + title=submission.title, + body=submission.selftext, + ) + all_posts.append(post) + yield "post", post + yield "posts", all_posts + except Exception as e: + yield "error", str(e) + + +class RedditGetMyPostsBlock(Block): + """Get posts by the authenticated Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_limit: int = SchemaField( + description="Maximum number of posts to fetch", + default=10, + ) + sort: UserPostSort = SchemaField( + description="Sort order for posts", + default="new", + ) + + class Output(BlockSchemaOutput): + post: RedditPost = SchemaField(description="A post by you") + posts: list[RedditPost] = SchemaField(description="All your posts") + error: str = SchemaField( + description="Error message if posts couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="4ab3381b-0c07-4201-89b3-fa2ec264f154", + description="Fetch posts created by the authenticated Reddit user (you).", + categories={BlockCategory.SOCIAL}, + input_schema=RedditGetMyPostsBlock.Input, + output_schema=RedditGetMyPostsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_limit": 2, + }, + test_output=[ + ( + "post", + RedditPost( + post_id="id1", subreddit="sub1", title="title1", body="body1" + ), + ), + ( + "post", + RedditPost( + post_id="id2", subreddit="sub2", title="title2", body="body2" + ), + ), + ( + "posts", + [ + RedditPost( + post_id="id1", + subreddit="sub1", + title="title1", + body="body1", + ), + RedditPost( + post_id="id2", + subreddit="sub2", + title="title2", + body="body2", + ), + ], + ), + ], + test_mock={ + "get_my_posts": lambda creds, limit, sort: [ + MockObject( + id="id1", + subreddit=MockObject(display_name="sub1"), + title="title1", + selftext="body1", + ), + MockObject( + id="id2", + subreddit=MockObject(display_name="sub2"), + title="title2", + selftext="body2", + ), + ] + }, + ) + + @staticmethod + def get_my_posts( + creds: RedditCredentials, limit: int, sort: UserPostSort + ) -> list[Submission]: + client = get_praw(creds) + me = client.user.me() + if not me: + raise ValueError("Could not get authenticated user.") + + if sort == "new": + submissions = me.submissions.new(limit=limit) + elif sort == "hot": + submissions = me.submissions.hot(limit=limit) + elif sort == "top": + submissions = me.submissions.top(limit=limit) + elif sort == "controversial": + submissions = me.submissions.controversial(limit=limit) + else: + submissions = me.submissions.new(limit=limit) + + return list(submissions) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + submissions = self.get_my_posts( + credentials, + input_data.post_limit, + input_data.sort, + ) + all_posts = [] + for submission in submissions: + post = RedditPost( + post_id=submission.id, + subreddit=submission.subreddit.display_name, + title=submission.title, + body=submission.selftext, + ) + all_posts.append(post) + yield "post", post + yield "posts", all_posts + except Exception as e: + yield "error", str(e) + + +class RedditSearchResult(BaseModel): + """A search result from Reddit.""" + + id: str + subreddit: str + title: str + body: str + author: 
str + score: int + num_comments: int + created_utc: float + permalink: str + + +class SearchRedditBlock(Block): + """Search Reddit for posts matching a query.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + query: str = SchemaField( + description="Search query string", + ) + subreddit: str | None = SchemaField( + description="Limit search to a specific subreddit (without /r/ prefix)", + default=None, + ) + sort: SearchSort = SchemaField( + description="Sort order for search results", + default="relevance", + ) + time_filter: TimeFilter = SchemaField( + description="Time filter for search results", + default="all", + ) + limit: int = SchemaField( + description="Maximum number of results to return", + default=10, + ) + + class Output(BlockSchemaOutput): + result: RedditSearchResult = SchemaField(description="A search result") + results: list[RedditSearchResult] = SchemaField( + description="All search results" + ) + error: str = SchemaField(description="Error message if search failed") + + def __init__(self): + super().__init__( + id="4a0c975e-807b-4d5e-83c9-1619864a4b1a", + description="Search Reddit for posts matching a query. Can search all of Reddit or a specific subreddit.", + categories={BlockCategory.SOCIAL}, + input_schema=SearchRedditBlock.Input, + output_schema=SearchRedditBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "query": "test query", + "limit": 2, + }, + test_output=[ + ( + "result", + RedditSearchResult( + id="id1", + subreddit="sub1", + title="title1", + body="body1", + author="author1", + score=100, + num_comments=10, + created_utc=1234567890.0, + permalink="/r/sub1/comments/id1/title1/", + ), + ), + ( + "result", + RedditSearchResult( + id="id2", + subreddit="sub2", + title="title2", + body="body2", + author="author2", + score=50, + num_comments=5, + created_utc=1234567891.0, + permalink="/r/sub2/comments/id2/title2/", + ), + ), + ( + "results", + [ + RedditSearchResult( + id="id1", + subreddit="sub1", + title="title1", + body="body1", + author="author1", + score=100, + num_comments=10, + created_utc=1234567890.0, + permalink="/r/sub1/comments/id1/title1/", + ), + RedditSearchResult( + id="id2", + subreddit="sub2", + title="title2", + body="body2", + author="author2", + score=50, + num_comments=5, + created_utc=1234567891.0, + permalink="/r/sub2/comments/id2/title2/", + ), + ], + ), + ], + test_mock={ + "search_reddit": lambda creds, query, subreddit, sort, time_filter, limit: [ + MockObject( + id="id1", + subreddit=MockObject(display_name="sub1"), + title="title1", + selftext="body1", + author="author1", + score=100, + num_comments=10, + created_utc=1234567890.0, + permalink="/r/sub1/comments/id1/title1/", + ), + MockObject( + id="id2", + subreddit=MockObject(display_name="sub2"), + title="title2", + selftext="body2", + author="author2", + score=50, + num_comments=5, + created_utc=1234567891.0, + permalink="/r/sub2/comments/id2/title2/", + ), + ] + }, + ) + + @staticmethod + def search_reddit( + creds: RedditCredentials, + query: str, + subreddit: str | None, + sort: SearchSort, + time_filter: TimeFilter, + limit: int, + ) -> list[Submission]: + client = get_praw(creds) + + if subreddit: + sub = client.subreddit(subreddit) + results = sub.search(query, sort=sort, time_filter=time_filter, limit=limit) + else: + results = 
client.subreddit("all").search( + query, sort=sort, time_filter=time_filter, limit=limit + ) + + return list(results) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + submissions = self.search_reddit( + credentials, + input_data.query, + input_data.subreddit, + input_data.sort, + input_data.time_filter, + input_data.limit, + ) + all_results = [] + for submission in submissions: + result = RedditSearchResult( + id=submission.id, + subreddit=submission.subreddit.display_name, + title=submission.title, + body=submission.selftext, + author=str(submission.author) if submission.author else "[deleted]", + score=submission.score, + num_comments=submission.num_comments, + created_utc=submission.created_utc, + permalink=submission.permalink, + ) + all_results.append(result) + yield "result", result + yield "results", all_results + except Exception as e: + yield "error", str(e) + + +class EditRedditPostBlock(Block): + """Edit an existing Reddit post that you own.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_id: str = SchemaField( + description="The ID of the post to edit (must be your own post)", + ) + new_content: str = SchemaField( + description="The new body text for the post", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the edit was successful") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + post_url: str = SchemaField(description="URL of the edited post") + error: str = SchemaField(description="Error message if the edit failed") + + def __init__(self): + super().__init__( + id="cdb9df0f-8b1d-433e-873a-ededc1b6479d", + description="Edit the body text of an existing Reddit post that you own. Only works for self/text posts.", + categories={BlockCategory.SOCIAL}, + input_schema=EditRedditPostBlock.Input, + output_schema=EditRedditPostBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_id": "abc123", + "new_content": "Updated post content", + }, + test_output=[ + ("success", True), + ("post_id", "abc123"), + ("post_url", "https://reddit.com/r/test/comments/abc123/test_post/"), + ], + test_mock={ + "edit_post": lambda creds, post_id, new_content: ( + True, + "https://reddit.com/r/test/comments/abc123/test_post/", + ) + }, + ) + + @staticmethod + def edit_post( + creds: RedditCredentials, post_id: str, new_content: str + ) -> tuple[bool, str]: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + submission.edit(new_content) + return True, f"https://reddit.com{submission.permalink}" + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success, post_url = self.edit_post( + credentials, input_data.post_id, input_data.new_content + ) + yield "success", success + yield "post_id", input_data.post_id + yield "post_url", post_url + except Exception as e: + error_msg = str(e) + if "403" in error_msg: + error_msg = ( + "Permission denied (403): You can only edit your own posts. " + "Make sure the post belongs to the authenticated Reddit account." 
+ ) + yield "error", error_msg + + +class SubredditInfo(BaseModel): + """Information about a subreddit.""" + + name: str + display_name: str + title: str + description: str + public_description: str + subscribers: int + created_utc: float + over_18: bool + url: str + + +class GetSubredditInfoBlock(Block): + """Get information about a subreddit.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit name (without /r/ prefix)", + ) + + class Output(BlockSchemaOutput): + info: SubredditInfo = SchemaField(description="Subreddit information") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + error: str = SchemaField( + description="Error message if the subreddit couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="5a2d1f0c-01fb-43ea-bad7-2260d269c930", + description="Get information about a subreddit including subscriber count, description, and rules.", + categories={BlockCategory.SOCIAL}, + input_schema=GetSubredditInfoBlock.Input, + output_schema=GetSubredditInfoBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "python", + }, + test_output=[ + ( + "info", + SubredditInfo( + name="t5_2qh0y", + display_name="python", + title="Python", + description="News about the Python programming language", + public_description="News about Python", + subscribers=1000000, + created_utc=1234567890.0, + over_18=False, + url="/r/python/", + ), + ), + ("subreddit", "python"), + ], + test_mock={ + "get_subreddit_info": lambda creds, subreddit: SubredditInfo( + name="t5_2qh0y", + display_name="python", + title="Python", + description="News about the Python programming language", + public_description="News about Python", + subscribers=1000000, + created_utc=1234567890.0, + over_18=False, + url="/r/python/", + ) + }, + ) + + @staticmethod + def get_subreddit_info(creds: RedditCredentials, subreddit: str) -> SubredditInfo: + client = get_praw(creds) + sub = client.subreddit(subreddit) + + return SubredditInfo( + name=sub.name, + display_name=sub.display_name, + title=sub.title, + description=sub.description, + public_description=sub.public_description, + subscribers=sub.subscribers, + created_utc=sub.created_utc, + over_18=sub.over18, + url=sub.url, + ) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + info = self.get_subreddit_info(credentials, input_data.subreddit) + yield "info", info + yield "subreddit", input_data.subreddit + except Exception as e: + yield "error", str(e) + + +class RedditComment(BaseModel): + """A Reddit comment.""" + + comment_id: str + post_id: str + parent_comment_id: str | None + author: str + body: str + score: int + created_utc: float + edited: bool + is_submitter: bool + permalink: str + depth: int + + +class GetRedditPostCommentsBlock(Block): + """Get comments on a Reddit post.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_id: str = SchemaField( + description="The ID of the post to get comments from", + ) + limit: int = SchemaField( + description="Maximum number of top-level comments to fetch (max 100)", + default=25, + ) + sort: CommentSort = SchemaField( + description="Sort order for comments", + 
default="best", + ) + + class Output(BlockSchemaOutput): + comment: RedditComment = SchemaField(description="A comment on the post") + comments: list[RedditComment] = SchemaField(description="All fetched comments") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + error: str = SchemaField( + description="Error message if comments couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="98422b2c-c3b0-4d70-871f-56bd966f46da", + description="Get top-level comments on a Reddit post.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditPostCommentsBlock.Input, + output_schema=GetRedditPostCommentsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_id": "abc123", + "limit": 2, + }, + test_output=[ + ( + "comment", + RedditComment( + comment_id="comment1", + post_id="abc123", + parent_comment_id=None, + author="user1", + body="Comment body 1", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ), + ), + ( + "comment", + RedditComment( + comment_id="comment2", + post_id="abc123", + parent_comment_id=None, + author="user2", + body="Comment body 2", + score=5, + created_utc=1234567891.0, + edited=False, + is_submitter=True, + permalink="/r/test/comments/abc123/test/comment2/", + depth=0, + ), + ), + ( + "comments", + [ + RedditComment( + comment_id="comment1", + post_id="abc123", + parent_comment_id=None, + author="user1", + body="Comment body 1", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ), + RedditComment( + comment_id="comment2", + post_id="abc123", + parent_comment_id=None, + author="user2", + body="Comment body 2", + score=5, + created_utc=1234567891.0, + edited=False, + is_submitter=True, + permalink="/r/test/comments/abc123/test/comment2/", + depth=0, + ), + ], + ), + ("post_id", "abc123"), + ], + test_mock={ + "get_comments": lambda creds, post_id, limit, sort: [ + MockObject( + id="comment1", + link_id="t3_abc123", + parent_id="t3_abc123", + author="user1", + body="Comment body 1", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ), + MockObject( + id="comment2", + link_id="t3_abc123", + parent_id="t3_abc123", + author="user2", + body="Comment body 2", + score=5, + created_utc=1234567891.0, + edited=False, + is_submitter=True, + permalink="/r/test/comments/abc123/test/comment2/", + depth=0, + ), + ] + }, + ) + + @staticmethod + def get_comments( + creds: RedditCredentials, post_id: str, limit: int, sort: CommentSort + ) -> list[Comment]: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + submission.comment_sort = sort + # Replace MoreComments with actual comments up to limit + submission.comments.replace_more(limit=0) + # Return only top-level comments (depth=0), limited + # CommentForest supports indexing, so use slicing directly + max_comments = min(limit, 100) + return [ + submission.comments[i] + for i in range(min(len(submission.comments), max_comments)) + ] + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + comments = 
self.get_comments( + credentials, + input_data.post_id, + input_data.limit, + input_data.sort, + ) + all_comments = [] + for comment in comments: + # Extract post_id from link_id (format: t3_xxxxx) + comment_post_id = strip_reddit_prefix(comment.link_id) + + # parent_comment_id is None for top-level comments (parent is a post: t3_) + # For replies, extract the comment ID from t1_xxxxx + parent_comment_id = None + if comment.parent_id.startswith("t1_"): + parent_comment_id = strip_reddit_prefix(comment.parent_id) + + comment_data = RedditComment( + comment_id=comment.id, + post_id=comment_post_id, + parent_comment_id=parent_comment_id, + author=str(comment.author) if comment.author else "[deleted]", + body=comment.body, + score=comment.score, + created_utc=comment.created_utc, + edited=bool(comment.edited), + is_submitter=comment.is_submitter, + permalink=comment.permalink, + depth=comment.depth, + ) + all_comments.append(comment_data) + yield "comment", comment_data + yield "comments", all_comments + yield "post_id", input_data.post_id + except Exception as e: + yield "error", str(e) + + +class GetRedditCommentRepliesBlock(Block): + """Get replies to a specific Reddit comment.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to get replies from", + ) + post_id: str = SchemaField( + description="The ID of the post containing the comment", + ) + limit: int = SchemaField( + description="Maximum number of replies to fetch (max 50)", + default=10, + ) + + class Output(BlockSchemaOutput): + reply: RedditComment = SchemaField(description="A reply to the comment") + replies: list[RedditComment] = SchemaField(description="All replies") + comment_id: str = SchemaField( + description="The parent comment ID (pass-through for chaining)" + ) + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + error: str = SchemaField( + description="Error message if replies couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="7fa83965-7289-432f-98a9-1575f5bcc8f1", + description="Get replies to a specific Reddit comment.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditCommentRepliesBlock.Input, + output_schema=GetRedditCommentRepliesBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "comment1", + "post_id": "abc123", + "limit": 2, + }, + test_output=[ + ( + "reply", + RedditComment( + comment_id="reply1", + post_id="abc123", + parent_comment_id="comment1", + author="replier1", + body="Reply body 1", + score=3, + created_utc=1234567892.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/reply1/", + depth=1, + ), + ), + ( + "replies", + [ + RedditComment( + comment_id="reply1", + post_id="abc123", + parent_comment_id="comment1", + author="replier1", + body="Reply body 1", + score=3, + created_utc=1234567892.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/reply1/", + depth=1, + ), + ], + ), + ("comment_id", "comment1"), + ("post_id", "abc123"), + ], + test_mock={ + "get_replies": lambda creds, comment_id, post_id, limit: [ + MockObject( + id="reply1", + link_id="t3_abc123", + parent_id="t1_comment1", + author="replier1", + body="Reply body 1", + score=3, + created_utc=1234567892.0, + 
edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/reply1/", + depth=1, + ), + ] + }, + ) + + @staticmethod + def get_replies( + creds: RedditCredentials, comment_id: str, post_id: str, limit: int + ) -> list[Comment]: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + comment_id = strip_reddit_prefix(comment_id) + + # Get the submission and find the comment + submission = client.submission(id=post_id) + submission.comments.replace_more(limit=0) + + # Find the target comment - filter out MoreComments which don't have .id + comment = None + for c in submission.comments.list(): + if isinstance(c, MoreComments): + continue + if c.id == comment_id: + comment = c + break + + if not comment: + return [] + + # Get direct replies - filter out MoreComments objects + replies = [] + # CommentForest supports indexing + for i in range(len(comment.replies)): + reply = comment.replies[i] + if isinstance(reply, MoreComments): + continue + replies.append(reply) + if len(replies) >= min(limit, 50): + break + + return replies + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + replies = self.get_replies( + credentials, + input_data.comment_id, + input_data.post_id, + input_data.limit, + ) + all_replies = [] + for reply in replies: + reply_post_id = strip_reddit_prefix(reply.link_id) + + # parent_comment_id is the parent comment (always present for replies) + parent_comment_id = None + if reply.parent_id.startswith("t1_"): + parent_comment_id = strip_reddit_prefix(reply.parent_id) + + reply_data = RedditComment( + comment_id=reply.id, + post_id=reply_post_id, + parent_comment_id=parent_comment_id, + author=str(reply.author) if reply.author else "[deleted]", + body=reply.body, + score=reply.score, + created_utc=reply.created_utc, + edited=bool(reply.edited), + is_submitter=reply.is_submitter, + permalink=reply.permalink, + depth=reply.depth, + ) + all_replies.append(reply_data) + yield "reply", reply_data + yield "replies", all_replies + yield "comment_id", input_data.comment_id + yield "post_id", input_data.post_id + except Exception as e: + yield "error", str(e) + + +class GetRedditCommentBlock(Block): + """Get details about a specific Reddit comment.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to fetch", + ) + + class Output(BlockSchemaOutput): + comment: RedditComment = SchemaField(description="The comment details") + error: str = SchemaField( + description="Error message if comment couldn't be fetched" + ) + + def __init__(self): + super().__init__( + id="72cb311a-5998-4e0a-9bc4-f1b67a97284e", + description="Get details about a specific Reddit comment by its ID.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditCommentBlock.Input, + output_schema=GetRedditCommentBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "comment1", + }, + test_output=[ + ( + "comment", + RedditComment( + comment_id="comment1", + post_id="abc123", + parent_comment_id=None, + author="user1", + body="Comment body", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ), + ), + ], + test_mock={ + "get_comment": lambda 
creds, comment_id: MockObject( + id="comment1", + link_id="t3_abc123", + parent_id="t3_abc123", + author="user1", + body="Comment body", + score=10, + created_utc=1234567890.0, + edited=False, + is_submitter=False, + permalink="/r/test/comments/abc123/test/comment1/", + depth=0, + ) + }, + ) + + @staticmethod + def get_comment(creds: RedditCredentials, comment_id: str): + client = get_praw(creds) + comment_id = strip_reddit_prefix(comment_id) + return client.comment(id=comment_id) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + comment = self.get_comment(credentials, input_data.comment_id) + + post_id = strip_reddit_prefix(comment.link_id) + + # parent_comment_id is None for top-level comments (parent is a post: t3_) + parent_comment_id = None + if comment.parent_id.startswith("t1_"): + parent_comment_id = strip_reddit_prefix(comment.parent_id) + + comment_data = RedditComment( + comment_id=comment.id, + post_id=post_id, + parent_comment_id=parent_comment_id, + author=str(comment.author) if comment.author else "[deleted]", + body=comment.body, + score=comment.score, + created_utc=comment.created_utc, + edited=bool(comment.edited), + is_submitter=comment.is_submitter, + permalink=comment.permalink, + # depth is only available when comments are fetched as part of a tree, + # not when fetched directly by ID + depth=getattr(comment, "depth", 0), + ) + yield "comment", comment_data + except Exception as e: + yield "error", str(e) + + +class ReplyToRedditCommentBlock(Block): + """Reply to a specific Reddit comment.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to reply to", + ) + reply_text: str = SchemaField( + description="The text content of the reply", + ) + + class Output(BlockSchemaOutput): + comment_id: str = SchemaField(description="ID of the newly created reply") + parent_comment_id: str = SchemaField( + description="The parent comment ID (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if reply failed") + + def __init__(self): + super().__init__( + id="7635b059-3a9f-4f7d-b499-1b56c4f76f4f", + description="Reply to a specific Reddit comment. 
Useful for threaded conversations.", + categories={BlockCategory.SOCIAL}, + input_schema=ReplyToRedditCommentBlock.Input, + output_schema=ReplyToRedditCommentBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "parent_comment", + "reply_text": "This is a reply", + }, + test_output=[ + ("comment_id", "new_reply_id"), + ("parent_comment_id", "parent_comment"), + ], + test_mock={ + "reply_to_comment": lambda creds, comment_id, reply_text: "new_reply_id" + }, + ) + + @staticmethod + def reply_to_comment( + creds: RedditCredentials, comment_id: str, reply_text: str + ) -> str: + client = get_praw(creds) + comment_id = strip_reddit_prefix(comment_id) + comment = client.comment(id=comment_id) + reply = comment.reply(reply_text) + if not reply: + raise ValueError("Failed to post reply.") + return reply.id + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + new_comment_id = self.reply_to_comment( + credentials, input_data.comment_id, input_data.reply_text + ) + yield "comment_id", new_comment_id + yield "parent_comment_id", input_data.comment_id + except Exception as e: + yield "error", str(e) + + +class RedditUserProfileSubreddit(BaseModel): + """Information about a user's profile subreddit.""" + + name: str + title: str + public_description: str + subscribers: int + over_18: bool + + +class RedditUserInfo(BaseModel): + """Information about a Reddit user.""" + + username: str + user_id: str + comment_karma: int + link_karma: int + total_karma: int + created_utc: float + is_gold: bool + is_mod: bool + has_verified_email: bool + moderated_subreddits: list[str] + profile_subreddit: RedditUserProfileSubreddit | None + + +class GetRedditUserInfoBlock(Block): + """Get information about a Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + username: str = SchemaField( + description="The Reddit username to look up (without /u/ prefix)", + ) + + class Output(BlockSchemaOutput): + user: RedditUserInfo = SchemaField(description="User information") + username: str = SchemaField( + description="The username (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if user lookup failed") + + def __init__(self): + super().__init__( + id="1b4c6bd1-4f28-4bad-9ae9-e7034a0f61ff", + description="Get information about a Reddit user including karma, account age, and verification status.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditUserInfoBlock.Input, + output_schema=GetRedditUserInfoBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "username": "testuser", + }, + test_output=[ + ( + "user", + RedditUserInfo( + username="testuser", + user_id="abc123", + comment_karma=1000, + link_karma=500, + total_karma=1500, + created_utc=1234567890.0, + is_gold=False, + is_mod=True, + has_verified_email=True, + moderated_subreddits=["python", "learnpython"], + profile_subreddit=RedditUserProfileSubreddit( + name="u_testuser", + title="testuser's profile", + public_description="A test user", + subscribers=100, + over_18=False, + ), + ), + ), + ("username", "testuser"), + ], + test_mock={ + "get_user_info": lambda creds, 
username: MockObject( + name="testuser", + id="abc123", + comment_karma=1000, + link_karma=500, + total_karma=1500, + created_utc=1234567890.0, + is_gold=False, + is_mod=True, + has_verified_email=True, + subreddit=MockObject( + display_name="u_testuser", + title="testuser's profile", + public_description="A test user", + subscribers=100, + over18=False, + ), + ), + "get_moderated_subreddits": lambda creds, username: [ + MockObject(display_name="python"), + MockObject(display_name="learnpython"), + ], + }, + ) + + @staticmethod + def get_user_info(creds: RedditCredentials, username: str): + client = get_praw(creds) + if username.startswith("u/"): + username = username[2:] + return client.redditor(username) + + @staticmethod + def get_moderated_subreddits(creds: RedditCredentials, username: str) -> list: + client = get_praw(creds) + if username.startswith("u/"): + username = username[2:] + redditor = client.redditor(username) + return list(redditor.moderated()) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + redditor = self.get_user_info(credentials, input_data.username) + moderated = self.get_moderated_subreddits(credentials, input_data.username) + + # Extract moderated subreddit names + moderated_subreddits = [sub.display_name for sub in moderated] + + # Get profile subreddit info if available + profile_subreddit = None + if hasattr(redditor, "subreddit") and redditor.subreddit: + try: + profile_subreddit = RedditUserProfileSubreddit( + name=redditor.subreddit.display_name, + title=redditor.subreddit.title or "", + public_description=redditor.subreddit.public_description or "", + subscribers=redditor.subreddit.subscribers or 0, + over_18=( + redditor.subreddit.over18 + if hasattr(redditor.subreddit, "over18") + else False + ), + ) + except Exception: + # Profile subreddit may not be accessible + pass + + user_info = RedditUserInfo( + username=redditor.name, + user_id=redditor.id, + comment_karma=redditor.comment_karma, + link_karma=redditor.link_karma, + total_karma=redditor.total_karma, + created_utc=redditor.created_utc, + is_gold=redditor.is_gold, + is_mod=redditor.is_mod, + has_verified_email=redditor.has_verified_email, + moderated_subreddits=moderated_subreddits, + profile_subreddit=profile_subreddit, + ) + yield "user", user_info + yield "username", input_data.username + except Exception as e: + yield "error", str(e) + + +class SendRedditMessageBlock(Block): + """Send a private message to a Reddit user.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + username: str = SchemaField( + description="The Reddit username to send a message to (without /u/ prefix)", + ) + subject: str = SchemaField( + description="The subject line of the message", + ) + message: str = SchemaField( + description="The body content of the message", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the message was sent") + username: str = SchemaField( + description="The username (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if sending failed") + + def __init__(self): + super().__init__( + id="7921101a-0537-4259-82ea-bc186ca6b1b6", + description="Send a private message (DM) to a Reddit user.", + categories={BlockCategory.SOCIAL}, + input_schema=SendRedditMessageBlock.Input, + output_schema=SendRedditMessageBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not 
settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "username": "testuser", + "subject": "Hello", + "message": "This is a test message", + }, + test_output=[ + ("success", True), + ("username", "testuser"), + ], + test_mock={"send_message": lambda creds, username, subject, message: True}, + ) + + @staticmethod + def send_message( + creds: RedditCredentials, username: str, subject: str, message: str + ) -> bool: + client = get_praw(creds) + if username.startswith("u/"): + username = username[2:] + redditor = client.redditor(username) + redditor.message(subject=subject, message=message) + return True + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success = self.send_message( + credentials, + input_data.username, + input_data.subject, + input_data.message, + ) + yield "success", success + yield "username", input_data.username + except Exception as e: + yield "error", str(e) + + +class RedditInboxItem(BaseModel): + """A Reddit inbox item (message, comment reply, or mention).""" + + item_id: str + item_type: str # "message", "comment_reply", "mention" + subject: str + body: str + author: str + created_utc: float + is_read: bool + context: str | None # permalink for comments, None for messages + + +class GetRedditInboxBlock(Block): + """Get messages and notifications from Reddit inbox.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + inbox_type: InboxType = SchemaField( + description="Type of inbox items to fetch", + default="unread", + ) + limit: int = SchemaField( + description="Maximum number of items to fetch", + default=25, + ) + mark_read: bool = SchemaField( + description="Whether to mark fetched items as read", + default=False, + ) + + class Output(BlockSchemaOutput): + item: RedditInboxItem = SchemaField(description="An inbox item") + items: list[RedditInboxItem] = SchemaField(description="All fetched items") + error: str = SchemaField(description="Error message if fetch failed") + + def __init__(self): + super().__init__( + id="5a91bb34-7ffe-4b9e-957b-9d4f8fe8dbc9", + description="Get messages, mentions, and comment replies from your Reddit inbox.", + categories={BlockCategory.SOCIAL}, + input_schema=GetRedditInboxBlock.Input, + output_schema=GetRedditInboxBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "inbox_type": "unread", + "limit": 10, + }, + test_output=[ + ( + "item", + RedditInboxItem( + item_id="msg123", + item_type="message", + subject="Hello", + body="Test message body", + author="sender_user", + created_utc=1234567890.0, + is_read=False, + context=None, + ), + ), + ( + "items", + [ + RedditInboxItem( + item_id="msg123", + item_type="message", + subject="Hello", + body="Test message body", + author="sender_user", + created_utc=1234567890.0, + is_read=False, + context=None, + ), + ], + ), + ], + test_mock={ + "get_inbox": lambda creds, inbox_type, limit: [ + MockObject( + id="msg123", + subject="Hello", + body="Test message body", + author="sender_user", + created_utc=1234567890.0, + new=True, + context=None, + was_comment=False, + ), + ] + }, + ) + + @staticmethod + def get_inbox(creds: RedditCredentials, inbox_type: InboxType, limit: int) -> list: + client = get_praw(creds) + inbox = client.inbox + + 
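+        # A table-driven equivalent of the if/elif dispatch below (reviewer sketch,
+        # not part of the change; all five are real PRAW Inbox listing methods):
+        #   fetchers = {
+        #       "all": inbox.all,
+        #       "messages": inbox.messages,
+        #       "mentions": inbox.mentions,
+        #       "comment_replies": inbox.comment_replies,
+        #   }
+        #   items = fetchers.get(inbox_type, inbox.unread)(limit=limit)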
if inbox_type == "all": + items = inbox.all(limit=limit) + elif inbox_type == "unread": + items = inbox.unread(limit=limit) + elif inbox_type == "messages": + items = inbox.messages(limit=limit) + elif inbox_type == "mentions": + items = inbox.mentions(limit=limit) + elif inbox_type == "comment_replies": + items = inbox.comment_replies(limit=limit) + else: + items = inbox.unread(limit=limit) + + return list(items) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + raw_items = self.get_inbox( + credentials, input_data.inbox_type, input_data.limit + ) + all_items = [] + + for item in raw_items: + # Determine item type + if hasattr(item, "was_comment") and item.was_comment: + if hasattr(item, "subject") and "mention" in item.subject.lower(): + item_type = "mention" + else: + item_type = "comment_reply" + else: + item_type = "message" + + inbox_item = RedditInboxItem( + item_id=item.id, + item_type=item_type, + subject=item.subject if hasattr(item, "subject") else "", + body=item.body, + author=str(item.author) if item.author else "[deleted]", + created_utc=item.created_utc, + is_read=not item.new, + context=item.context if hasattr(item, "context") else None, + ) + all_items.append(inbox_item) + yield "item", inbox_item + + # Mark as read if requested + if input_data.mark_read and raw_items: + client = get_praw(credentials) + client.inbox.mark_read(raw_items) + + yield "items", all_items + except Exception as e: + yield "error", str(e) + + +class DeleteRedditPostBlock(Block): + """Delete a Reddit post that you own.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + post_id: str = SchemaField( + description="The ID of the post to delete (must be your own post)", + ) + + class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the deletion was successful") + post_id: str = SchemaField( + description="The post ID (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if deletion failed") + + def __init__(self): + super().__init__( + id="72e4730a-d66d-4785-8e54-5ab3af450c81", + description="Delete a Reddit post that you own.", + categories={BlockCategory.SOCIAL}, + input_schema=DeleteRedditPostBlock.Input, + output_schema=DeleteRedditPostBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "post_id": "abc123", + }, + test_output=[ + ("success", True), + ("post_id", "abc123"), + ], + test_mock={"delete_post": lambda creds, post_id: True}, + ) + + @staticmethod + def delete_post(creds: RedditCredentials, post_id: str) -> bool: + client = get_praw(creds) + post_id = strip_reddit_prefix(post_id) + submission = client.submission(id=post_id) + submission.delete() + return True + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success = self.delete_post(credentials, input_data.post_id) + yield "success", success + yield "post_id", input_data.post_id + except Exception as e: + yield "error", str(e) + + +class DeleteRedditCommentBlock(Block): + """Delete a Reddit comment that you own.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + comment_id: str = SchemaField( + description="The ID of the comment to delete (must be your own comment)", + ) + + 
class Output(BlockSchemaOutput): + success: bool = SchemaField(description="Whether the deletion was successful") + comment_id: str = SchemaField( + description="The comment ID (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if deletion failed") + + def __init__(self): + super().__init__( + id="2650584d-434f-46db-81ef-26c8d8d41f81", + description="Delete a Reddit comment that you own.", + categories={BlockCategory.SOCIAL}, + input_schema=DeleteRedditCommentBlock.Input, + output_schema=DeleteRedditCommentBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "comment_id": "xyz789", + }, + test_output=[ + ("success", True), + ("comment_id", "xyz789"), + ], + test_mock={"delete_comment": lambda creds, comment_id: True}, + ) + + @staticmethod + def delete_comment(creds: RedditCredentials, comment_id: str) -> bool: + client = get_praw(creds) + comment_id = strip_reddit_prefix(comment_id) + comment = client.comment(id=comment_id) + comment.delete() + return True + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + success = self.delete_comment(credentials, input_data.comment_id) + yield "success", success + yield "comment_id", input_data.comment_id + except Exception as e: + yield "error", str(e) + + +class SubredditFlair(BaseModel): + """A subreddit link flair template.""" + + flair_id: str + text: str + text_editable: bool + css_class: str = "" # The CSS class for styling (from flair_css_class) + + +class GetSubredditFlairsBlock(Block): + """Get available link flairs for a subreddit.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit name (without /r/ prefix)", + ) + + class Output(BlockSchemaOutput): + flair: SubredditFlair = SchemaField(description="A flair option") + flairs: list[SubredditFlair] = SchemaField(description="All available flairs") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if fetch failed") + + def __init__(self): + super().__init__( + id="ada08f34-a7a9-44aa-869f-0638fa4e0a84", + description="Get available link flair options for a subreddit.", + categories={BlockCategory.SOCIAL}, + input_schema=GetSubredditFlairsBlock.Input, + output_schema=GetSubredditFlairsBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "test", + }, + test_output=[ + ( + "flair", + SubredditFlair( + flair_id="abc123", + text="Discussion", + text_editable=False, + css_class="discussion", + ), + ), + ( + "flairs", + [ + SubredditFlair( + flair_id="abc123", + text="Discussion", + text_editable=False, + css_class="discussion", + ), + ], + ), + ("subreddit", "test"), + ], + test_mock={ + "get_flairs": lambda creds, subreddit: [ + { + "flair_template_id": "abc123", + "flair_text": "Discussion", + "flair_text_editable": False, + "flair_css_class": "discussion", + }, + ] + }, + ) + + @staticmethod + def get_flairs(creds: RedditCredentials, subreddit: str) -> list: + client = get_praw(creds) + # Use /r/{subreddit}/api/flairselector endpoint directly with 
is_newlink=True + # This returns link flairs available for new submissions without requiring mod access + # The link_templates API is moderator-only, so we use flairselector instead + # Path must include the subreddit prefix per Reddit API docs + response = client.post( + f"r/{subreddit}/api/flairselector", + data={"is_newlink": "true"}, + ) + # Response contains 'choices' list with available flairs + choices = response.get("choices", []) + return choices + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + raw_flairs = self.get_flairs(credentials, input_data.subreddit) + all_flairs = [] + + for flair in raw_flairs: + # /api/flairselector returns flairs with flair_template_id, flair_text, etc. + flair_data = SubredditFlair( + flair_id=flair.get("flair_template_id", ""), + text=flair.get("flair_text", ""), + text_editable=flair.get("flair_text_editable", False), + css_class=flair.get("flair_css_class", ""), + ) + all_flairs.append(flair_data) + yield "flair", flair_data + + yield "flairs", all_flairs + yield "subreddit", input_data.subreddit + except Exception as e: + yield "error", str(e) + + +class SubredditRule(BaseModel): + """A subreddit rule.""" + + short_name: str + description: str + kind: str # "all", "link", "comment" + violation_reason: str + priority: int + + +class GetSubredditRulesBlock(Block): + """Get the rules for a subreddit.""" + + class Input(BlockSchemaInput): + credentials: RedditCredentialsInput = RedditCredentialsField() + subreddit: str = SchemaField( + description="Subreddit name (without /r/ prefix)", + ) + + class Output(BlockSchemaOutput): + rule: SubredditRule = SchemaField(description="A subreddit rule") + rules: list[SubredditRule] = SchemaField(description="All subreddit rules") + subreddit: str = SchemaField( + description="The subreddit name (pass-through for chaining)" + ) + error: str = SchemaField(description="Error message if fetch failed") + + def __init__(self): + super().__init__( + id="222aa36c-fa70-4879-8e8a-37d100175f5c", + description="Get the rules for a subreddit to ensure compliance before posting.", + categories={BlockCategory.SOCIAL}, + input_schema=GetSubredditRulesBlock.Input, + output_schema=GetSubredditRulesBlock.Output, + disabled=( + not settings.secrets.reddit_client_id + or not settings.secrets.reddit_client_secret + ), + test_credentials=TEST_CREDENTIALS, + test_input={ + "credentials": TEST_CREDENTIALS_INPUT, + "subreddit": "test", + }, + test_output=[ + ( + "rule", + SubredditRule( + short_name="No spam", + description="Do not post spam or self-promotional content.", + kind="all", + violation_reason="Spam", + priority=0, + ), + ), + ( + "rules", + [ + SubredditRule( + short_name="No spam", + description="Do not post spam or self-promotional content.", + kind="all", + violation_reason="Spam", + priority=0, + ), + ], + ), + ("subreddit", "test"), + ], + test_mock={ + "get_rules": lambda creds, subreddit: [ + MockObject( + short_name="No spam", + description="Do not post spam or self-promotional content.", + kind="all", + violation_reason="Spam", + priority=0, + ), + ] + }, + ) + + @staticmethod + def get_rules(creds: RedditCredentials, subreddit: str) -> list: + client = get_praw(creds) + sub = client.subreddit(subreddit) + return list(sub.rules) + + async def run( + self, input_data: Input, *, credentials: RedditCredentials, **kwargs + ) -> BlockOutput: + try: + raw_rules = self.get_rules(credentials, input_data.subreddit) + all_rules = [] + + for idx, rule in 
enumerate(raw_rules): + rule_data = SubredditRule( + short_name=rule.short_name, + description=rule.description or "", + kind=rule.kind, + violation_reason=rule.violation_reason or rule.short_name, + priority=idx, + ) + all_rules.append(rule_data) + yield "rule", rule_data + + yield "rules", all_rules + yield "subreddit", input_data.subreddit + except Exception as e: + yield "error", str(e) diff --git a/autogpt_platform/backend/backend/blocks/smart_decision_maker.py b/autogpt_platform/backend/backend/blocks/smart_decision_maker.py index 751f6af37f..ff6042eaab 100644 --- a/autogpt_platform/backend/backend/blocks/smart_decision_maker.py +++ b/autogpt_platform/backend/backend/blocks/smart_decision_maker.py @@ -391,8 +391,12 @@ class SmartDecisionMakerBlock(Block): """ block = sink_node.block + # Use custom name from node metadata if set, otherwise fall back to block.name + custom_name = sink_node.metadata.get("customized_name") + tool_name = custom_name if custom_name else block.name + tool_function: dict[str, Any] = { - "name": SmartDecisionMakerBlock.cleanup(block.name), + "name": SmartDecisionMakerBlock.cleanup(tool_name), "description": block.description, } sink_block_input_schema = block.input_schema @@ -489,14 +493,24 @@ class SmartDecisionMakerBlock(Block): f"Sink graph metadata not found: {graph_id} {graph_version}" ) + # Use custom name from node metadata if set, otherwise fall back to graph name + custom_name = sink_node.metadata.get("customized_name") + tool_name = custom_name if custom_name else sink_graph_meta.name + tool_function: dict[str, Any] = { - "name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name), + "name": SmartDecisionMakerBlock.cleanup(tool_name), "description": sink_graph_meta.description, } properties = {} + field_mapping = {} for link in links: + field_name = link.sink_name + + clean_field_name = SmartDecisionMakerBlock.cleanup(field_name) + field_mapping[clean_field_name] = field_name + sink_block_input_schema = sink_node.input_default["input_schema"] sink_block_properties = sink_block_input_schema.get("properties", {}).get( link.sink_name, {} @@ -506,7 +520,7 @@ class SmartDecisionMakerBlock(Block): if "description" in sink_block_properties else f"The {link.sink_name} of the tool" ) - properties[link.sink_name] = { + properties[clean_field_name] = { "type": "string", "description": description, "default": json.dumps(sink_block_properties.get("default", None)), @@ -519,7 +533,7 @@ class SmartDecisionMakerBlock(Block): "strict": True, } - # Store node info for later use in output processing + tool_function["_field_mapping"] = field_mapping tool_function["_sink_node_id"] = sink_node.id return {"type": "function", "function": tool_function} @@ -975,10 +989,28 @@ class SmartDecisionMakerBlock(Block): graph_version: int, execution_context: ExecutionContext, execution_processor: "ExecutionProcessor", + nodes_to_skip: set[str] | None = None, **kwargs, ) -> BlockOutput: tool_functions = await self._create_tool_node_signatures(node_id) + original_tool_count = len(tool_functions) + + # Filter out tools for nodes that should be skipped (e.g., missing optional credentials) + if nodes_to_skip: + tool_functions = [ + tf + for tf in tool_functions + if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip + ] + + # Only raise error if we had tools but they were all filtered out + if original_tool_count > 0 and not tool_functions: + raise ValueError( + "No available tools to execute - all downstream nodes are unavailable " + "(possibly due to missing optional 
credentials)" + ) + yield "tool_functions", json.dumps(tool_functions) conversation_history = input_data.conversation_history or [] @@ -1129,8 +1161,9 @@ class SmartDecisionMakerBlock(Block): original_field_name = field_mapping.get(clean_arg_name, clean_arg_name) arg_value = tool_args.get(clean_arg_name) - sanitized_arg_name = self.cleanup(original_field_name) - emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}" + # Use original_field_name directly (not sanitized) to match link sink_name + # The field_mapping already translates from LLM's cleaned names to original names + emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}" logger.debug( "[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s", diff --git a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py index c930fab37e..8266d433ad 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py +++ b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py @@ -1057,3 +1057,153 @@ async def test_smart_decision_maker_traditional_mode_default(): ) # Should yield individual tool parameters assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs assert "conversations" in outputs + + +@pytest.mark.asyncio +async def test_smart_decision_maker_uses_customized_name_for_blocks(): + """Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names.""" + from unittest.mock import MagicMock + + from backend.blocks.basic import StoreValueBlock + from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock + from backend.data.graph import Link, Node + + # Create a mock node with customized_name in metadata + mock_node = MagicMock(spec=Node) + mock_node.id = "test-node-id" + mock_node.block_id = StoreValueBlock().id + mock_node.metadata = {"customized_name": "My Custom Tool Name"} + mock_node.block = StoreValueBlock() + + # Create a mock link + mock_link = MagicMock(spec=Link) + mock_link.sink_name = "input" + + # Call the function directly + result = await SmartDecisionMakerBlock._create_block_function_signature( + mock_node, [mock_link] + ) + + # Verify the tool name uses the customized name (cleaned up) + assert result["type"] == "function" + assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version + assert result["function"]["_sink_node_id"] == "test-node-id" + + +@pytest.mark.asyncio +async def test_smart_decision_maker_falls_back_to_block_name(): + """Test that SmartDecisionMakerBlock falls back to block.name when no customized_name.""" + from unittest.mock import MagicMock + + from backend.blocks.basic import StoreValueBlock + from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock + from backend.data.graph import Link, Node + + # Create a mock node without customized_name + mock_node = MagicMock(spec=Node) + mock_node.id = "test-node-id" + mock_node.block_id = StoreValueBlock().id + mock_node.metadata = {} # No customized_name + mock_node.block = StoreValueBlock() + + # Create a mock link + mock_link = MagicMock(spec=Link) + mock_link.sink_name = "input" + + # Call the function directly + result = await SmartDecisionMakerBlock._create_block_function_signature( + mock_node, [mock_link] + ) + + # Verify the tool name uses the block's default name + assert result["type"] == "function" + assert result["function"]["name"] == "storevalueblock" # Default block name cleaned + assert 
result["function"]["_sink_node_id"] == "test-node-id" + + +@pytest.mark.asyncio +async def test_smart_decision_maker_uses_customized_name_for_agents(): + """Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock + from backend.data.graph import Link, Node + + # Create a mock node with customized_name in metadata + mock_node = MagicMock(spec=Node) + mock_node.id = "test-agent-node-id" + mock_node.metadata = {"customized_name": "My Custom Agent"} + mock_node.input_default = { + "graph_id": "test-graph-id", + "graph_version": 1, + "input_schema": {"properties": {"test_input": {"description": "Test input"}}}, + } + + # Create a mock link + mock_link = MagicMock(spec=Link) + mock_link.sink_name = "test_input" + + # Mock the database client + mock_graph_meta = MagicMock() + mock_graph_meta.name = "Original Agent Name" + mock_graph_meta.description = "Agent description" + + mock_db_client = AsyncMock() + mock_db_client.get_graph_metadata.return_value = mock_graph_meta + + with patch( + "backend.blocks.smart_decision_maker.get_database_manager_async_client", + return_value=mock_db_client, + ): + result = await SmartDecisionMakerBlock._create_agent_function_signature( + mock_node, [mock_link] + ) + + # Verify the tool name uses the customized name (cleaned up) + assert result["type"] == "function" + assert result["function"]["name"] == "my_custom_agent" # Cleaned version + assert result["function"]["_sink_node_id"] == "test-agent-node-id" + + +@pytest.mark.asyncio +async def test_smart_decision_maker_agent_falls_back_to_graph_name(): + """Test that agent node falls back to graph name when no customized_name.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock + from backend.data.graph import Link, Node + + # Create a mock node without customized_name + mock_node = MagicMock(spec=Node) + mock_node.id = "test-agent-node-id" + mock_node.metadata = {} # No customized_name + mock_node.input_default = { + "graph_id": "test-graph-id", + "graph_version": 1, + "input_schema": {"properties": {"test_input": {"description": "Test input"}}}, + } + + # Create a mock link + mock_link = MagicMock(spec=Link) + mock_link.sink_name = "test_input" + + # Mock the database client + mock_graph_meta = MagicMock() + mock_graph_meta.name = "Original Agent Name" + mock_graph_meta.description = "Agent description" + + mock_db_client = AsyncMock() + mock_db_client.get_graph_metadata.return_value = mock_graph_meta + + with patch( + "backend.blocks.smart_decision_maker.get_database_manager_async_client", + return_value=mock_db_client, + ): + result = await SmartDecisionMakerBlock._create_agent_function_signature( + mock_node, [mock_link] + ) + + # Verify the tool name uses the graph's default name + assert result["type"] == "function" + assert result["function"]["name"] == "original_agent_name" # Graph name cleaned + assert result["function"]["_sink_node_id"] == "test-agent-node-id" diff --git a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dict.py b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dict.py index 839bdc5e15..2087c0b7d6 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dict.py +++ b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dict.py @@ -15,6 +15,7 @@ async def 
test_smart_decision_maker_handles_dynamic_dict_fields(): mock_node.block = CreateDictionaryBlock() mock_node.block_id = CreateDictionaryBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic dictionary fields mock_links = [ @@ -77,6 +78,7 @@ async def test_smart_decision_maker_handles_dynamic_list_fields(): mock_node.block = AddToListBlock() mock_node.block_id = AddToListBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic list fields mock_links = [ diff --git a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dynamic_fields.py b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dynamic_fields.py index 6ed830e517..af89a83f86 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dynamic_fields.py +++ b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker_dynamic_fields.py @@ -44,6 +44,7 @@ async def test_create_block_function_signature_with_dict_fields(): mock_node.block = CreateDictionaryBlock() mock_node.block_id = CreateDictionaryBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic dictionary fields (source sanitized, sink original) mock_links = [ @@ -106,6 +107,7 @@ async def test_create_block_function_signature_with_list_fields(): mock_node.block = AddToListBlock() mock_node.block_id = AddToListBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic list fields mock_links = [ @@ -159,6 +161,7 @@ async def test_create_block_function_signature_with_object_fields(): mock_node.block = MatchTextPatternBlock() mock_node.block_id = MatchTextPatternBlock().id mock_node.input_default = {} + mock_node.metadata = {} # Create mock links with dynamic object fields mock_links = [ @@ -208,11 +211,13 @@ async def test_create_tool_node_signatures(): mock_dict_node.block = CreateDictionaryBlock() mock_dict_node.block_id = CreateDictionaryBlock().id mock_dict_node.input_default = {} + mock_dict_node.metadata = {} mock_list_node = Mock() mock_list_node.block = AddToListBlock() mock_list_node.block_id = AddToListBlock().id mock_list_node.input_default = {} + mock_list_node.metadata = {} # Mock links with dynamic fields dict_link1 = Mock( @@ -423,6 +428,7 @@ async def test_mixed_regular_and_dynamic_fields(): mock_node.block.name = "TestBlock" mock_node.block.description = "A test block" mock_node.block.input_schema = Mock() + mock_node.metadata = {} # Mock the get_field_schema to return a proper schema for regular fields def get_field_schema(field_name): diff --git a/autogpt_platform/backend/backend/blocks/wordpress/__init__.py b/autogpt_platform/backend/backend/blocks/wordpress/__init__.py index c7b1e26eea..3eae4a1063 100644 --- a/autogpt_platform/backend/backend/blocks/wordpress/__init__.py +++ b/autogpt_platform/backend/backend/blocks/wordpress/__init__.py @@ -1,3 +1,3 @@ -from .blog import WordPressCreatePostBlock +from .blog import WordPressCreatePostBlock, WordPressGetAllPostsBlock -__all__ = ["WordPressCreatePostBlock"] +__all__ = ["WordPressCreatePostBlock", "WordPressGetAllPostsBlock"] diff --git a/autogpt_platform/backend/backend/blocks/wordpress/_api.py b/autogpt_platform/backend/backend/blocks/wordpress/_api.py index 78f535947b..d21dc3e05d 100644 --- a/autogpt_platform/backend/backend/blocks/wordpress/_api.py +++ b/autogpt_platform/backend/backend/blocks/wordpress/_api.py @@ -161,7 +161,7 @@ async def 
oauth_exchange_code_for_tokens( grant_type="authorization_code", ).model_dump(exclude_none=True) - response = await Requests().post( + response = await Requests(raise_for_status=False).post( f"{WORDPRESS_BASE_URL}oauth2/token", headers=headers, data=data, @@ -205,7 +205,7 @@ async def oauth_refresh_tokens( grant_type="refresh_token", ).model_dump(exclude_none=True) - response = await Requests().post( + response = await Requests(raise_for_status=False).post( f"{WORDPRESS_BASE_URL}oauth2/token", headers=headers, data=data, @@ -252,7 +252,7 @@ async def validate_token( "token": token, } - response = await Requests().get( + response = await Requests(raise_for_status=False).get( f"{WORDPRESS_BASE_URL}oauth2/token-info", params=params, ) @@ -296,7 +296,7 @@ async def make_api_request( url = f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}" - request_method = getattr(Requests(), method.lower()) + request_method = getattr(Requests(raise_for_status=False), method.lower()) response = await request_method( url, headers=headers, @@ -476,6 +476,7 @@ async def create_post( data["tags"] = ",".join(str(t) for t in data["tags"]) # Make the API request + site = normalize_site(site) endpoint = f"/rest/v1.1/sites/{site}/posts/new" headers = { @@ -483,7 +484,7 @@ async def create_post( "Content-Type": "application/x-www-form-urlencoded", } - response = await Requests().post( + response = await Requests(raise_for_status=False).post( f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}", headers=headers, data=data, @@ -499,3 +500,132 @@ async def create_post( ) error_message = error_data.get("message", response.text) raise ValueError(f"Failed to create post: {response.status} - {error_message}") + + +class Post(BaseModel): + """Response model for individual posts in a posts list response. + + This is a simplified version compared to PostResponse, as the list endpoint + returns less detailed information than the create/get single post endpoints. 
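+    Most fields are therefore optional (defaulting to None), so the same model
+    validates both the richer single-post payloads and the slimmer list entries.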
+ """ + + ID: int + site_ID: int + author: PostAuthor + date: datetime + modified: datetime + title: str + URL: str + short_URL: str + content: str | None = None + excerpt: str | None = None + slug: str + guid: str + status: str + sticky: bool + password: str | None = "" + parent: Union[Dict[str, Any], bool, None] = None + type: str + discussion: Dict[str, Union[str, bool, int]] | None = None + likes_enabled: bool | None = None + sharing_enabled: bool | None = None + like_count: int | None = None + i_like: bool | None = None + is_reblogged: bool | None = None + is_following: bool | None = None + global_ID: str | None = None + featured_image: str | None = None + post_thumbnail: Dict[str, Any] | None = None + format: str | None = None + geo: Union[Dict[str, Any], bool, None] = None + menu_order: int | None = None + page_template: str | None = None + publicize_URLs: List[str] | None = None + terms: Dict[str, Dict[str, Any]] | None = None + tags: Dict[str, Dict[str, Any]] | None = None + categories: Dict[str, Dict[str, Any]] | None = None + attachments: Dict[str, Dict[str, Any]] | None = None + attachment_count: int | None = None + metadata: List[Dict[str, Any]] | None = None + meta: Dict[str, Any] | None = None + capabilities: Dict[str, bool] | None = None + revisions: List[int] | None = None + other_URLs: Dict[str, Any] | None = None + + +class PostsResponse(BaseModel): + """Response model for WordPress posts list.""" + + found: int + posts: List[Post] + meta: Dict[str, Any] + + +def normalize_site(site: str) -> str: + """ + Normalize a site identifier by stripping protocol and trailing slashes. + + Args: + site: Site URL, domain, or ID (e.g., "https://myblog.wordpress.com/", "myblog.wordpress.com", "123456789") + + Returns: + Normalized site identifier (domain or ID only) + """ + site = site.strip() + if site.startswith("https://"): + site = site[8:] + elif site.startswith("http://"): + site = site[7:] + return site.rstrip("/") + + +async def get_posts( + credentials: Credentials, + site: str, + status: PostStatus | None = None, + number: int = 100, + offset: int = 0, +) -> PostsResponse: + """ + Get posts from a WordPress site. 
+ + Args: + credentials: OAuth credentials + site: Site ID or domain (e.g., "myblog.wordpress.com" or "123456789") + status: Filter by post status using PostStatus enum, or None for all + number: Number of posts to retrieve (max 100) + offset: Number of posts to skip (for pagination) + + Returns: + PostsResponse with the list of posts + """ + site = normalize_site(site) + endpoint = f"/rest/v1.1/sites/{site}/posts" + + headers = { + "Authorization": credentials.auth_header(), + } + + params: Dict[str, Any] = { + "number": max(1, min(number, 100)), # 1–100 posts per request + "offset": offset, + } + + if status: + params["status"] = status.value + response = await Requests(raise_for_status=False).get( + f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}", + headers=headers, + params=params, + ) + + if response.ok: + return PostsResponse.model_validate(response.json()) + + error_data = ( + response.json() + if response.headers.get("content-type", "").startswith("application/json") + else {} + ) + error_message = error_data.get("message", response.text) + raise ValueError(f"Failed to get posts: {response.status} - {error_message}") diff --git a/autogpt_platform/backend/backend/blocks/wordpress/blog.py b/autogpt_platform/backend/backend/blocks/wordpress/blog.py index c0ad5eca54..22b691480b 100644 --- a/autogpt_platform/backend/backend/blocks/wordpress/blog.py +++ b/autogpt_platform/backend/backend/blocks/wordpress/blog.py @@ -9,7 +9,15 @@ from backend.sdk import ( SchemaField, ) -from ._api import CreatePostRequest, PostResponse, PostStatus, create_post +from ._api import ( + CreatePostRequest, + Post, + PostResponse, + PostsResponse, + PostStatus, + create_post, + get_posts, +) from ._config import wordpress @@ -49,8 +57,15 @@ class WordPressCreatePostBlock(Block): media_urls: list[str] = SchemaField( description="URLs of images to sideload and attach to the post", default=[] ) + publish_as_draft: bool = SchemaField( + description="If True, publishes the post as a draft. If False, publishes it publicly.", + default=False, + ) class Output(BlockSchemaOutput): + site: str = SchemaField( + description="The site ID or domain (pass-through for chaining with other blocks)" + ) post_id: int = SchemaField(description="The ID of the created post") post_url: str = SchemaField(description="The full URL of the created post") short_url: str = SchemaField(description="The shortened wp.me URL") @@ -78,7 +93,9 @@ class WordPressCreatePostBlock(Block): tags=input_data.tags, featured_image=input_data.featured_image, media_urls=input_data.media_urls, - status=PostStatus.PUBLISH, + status=( + PostStatus.DRAFT if input_data.publish_as_draft else PostStatus.PUBLISH + ), ) post_response: PostResponse = await create_post( @@ -87,7 +104,69 @@ class WordPressCreatePostBlock(Block): post_data=post_request, ) + yield "site", input_data.site yield "post_id", post_response.ID yield "post_url", post_response.URL yield "short_url", post_response.short_URL yield "post_data", post_response.model_dump() + + +class WordPressGetAllPostsBlock(Block): + """ + Fetches all posts from a WordPress.com site or Jetpack-enabled site. + Supports filtering by status and pagination. 
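+
+    The underlying API caps `number` at 100 per request, so fetching a larger
+    site means running the block repeatedly with an increasing `offset`
+    (0, 100, 200, ...). Each post is also emitted individually on the `post`
+    output, as shown in the Output schema below.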
+ """ + + class Input(BlockSchemaInput): + credentials: CredentialsMetaInput = wordpress.credentials_field() + site: str = SchemaField( + description="Site ID or domain (e.g., 'myblog.wordpress.com' or '123456789')" + ) + status: PostStatus | None = SchemaField( + description="Filter by post status, or None for all", + default=None, + ) + number: int = SchemaField( + description="Number of posts to retrieve (max 100 per request)", default=20 + ) + offset: int = SchemaField( + description="Number of posts to skip (for pagination)", default=0 + ) + + class Output(BlockSchemaOutput): + site: str = SchemaField( + description="The site ID or domain (pass-through for chaining with other blocks)" + ) + found: int = SchemaField(description="Total number of posts found") + posts: list[Post] = SchemaField( + description="List of post objects with their details" + ) + post: Post = SchemaField( + description="Individual post object (yielded for each post)" + ) + + def __init__(self): + super().__init__( + id="97728fa7-7f6f-4789-ba0c-f2c114119536", + description="Fetch all posts from WordPress.com or Jetpack sites", + categories={BlockCategory.SOCIAL}, + input_schema=self.Input, + output_schema=self.Output, + ) + + async def run( + self, input_data: Input, *, credentials: Credentials, **kwargs + ) -> BlockOutput: + posts_response: PostsResponse = await get_posts( + credentials=credentials, + site=input_data.site, + status=input_data.status, + number=input_data.number, + offset=input_data.offset, + ) + + yield "site", input_data.site + yield "found", posts_response.found + yield "posts", posts_response.posts + for post in posts_response.posts: + yield "post", post diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 727688dcf0..24a68cca03 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -50,6 +50,8 @@ from .model import ( logger = logging.getLogger(__name__) if TYPE_CHECKING: + from backend.data.execution import ExecutionContext + from .graph import Link app_config = Config() @@ -472,6 +474,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): self.block_type = block_type self.webhook_config = webhook_config self.execution_stats: NodeExecutionStats = NodeExecutionStats() + self.requires_human_review: bool = False if self.webhook_config: if isinstance(self.webhook_config, BlockWebhookConfig): @@ -614,7 +617,77 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): block_id=self.id, ) from ex + async def is_block_exec_need_review( + self, + input_data: BlockInput, + *, + user_id: str, + node_exec_id: str, + graph_exec_id: str, + graph_id: str, + graph_version: int, + execution_context: "ExecutionContext", + **kwargs, + ) -> tuple[bool, BlockInput]: + """ + Check if this block execution needs human review and handle the review process. 
+ + Returns: + Tuple of (should_pause, input_data_to_use) + - should_pause: True if execution should be paused for review + - input_data_to_use: The input data to use (may be modified by reviewer) + """ + # Skip review if not required or safe mode is disabled + if not self.requires_human_review or not execution_context.safe_mode: + return False, input_data + + from backend.blocks.helpers.review import HITLReviewHelper + + # Handle the review request and get decision + decision = await HITLReviewHelper.handle_review_decision( + input_data=input_data, + user_id=user_id, + node_exec_id=node_exec_id, + graph_exec_id=graph_exec_id, + graph_id=graph_id, + graph_version=graph_version, + execution_context=execution_context, + block_name=self.name, + editable=True, + ) + + if decision is None: + # We're awaiting review - pause execution + return True, input_data + + if not decision.should_proceed: + # Review was rejected, raise an error to stop execution + raise BlockExecutionError( + message=f"Block execution rejected by reviewer: {decision.message}", + block_name=self.name, + block_id=self.id, + ) + + # Review was approved - use the potentially modified data + # ReviewResult.data must be a dict for block inputs + reviewed_data = decision.review_result.data + if not isinstance(reviewed_data, dict): + raise BlockExecutionError( + message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}", + block_name=self.name, + block_id=self.id, + ) + return False, reviewed_data + async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput: + # Check for review requirement and get potentially modified input data + should_pause, input_data = await self.is_block_exec_need_review( + input_data, **kwargs + ) + if should_pause: + return + + # Validate the input data (original or reviewer-modified) once if error := self.input_schema.validate_data(input_data): raise BlockInputError( message=f"Unable to execute block with invalid input data: {error}", @@ -622,6 +695,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): block_id=self.id, ) + # Use the validated input data async for output_name, output_data in self.run( self.input_schema(**{k: v for k, v in input_data.items() if v is not None}), **kwargs, diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index 020a5a1906..2759dfe179 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -383,6 +383,7 @@ class GraphExecutionWithNodes(GraphExecution): self, execution_context: ExecutionContext, compiled_nodes_input_masks: Optional[NodesInputMasks] = None, + nodes_to_skip: Optional[set[str]] = None, ): return GraphExecutionEntry( user_id=self.user_id, @@ -390,6 +391,7 @@ class GraphExecutionWithNodes(GraphExecution): graph_version=self.graph_version or 0, graph_exec_id=self.id, nodes_input_masks=compiled_nodes_input_masks, + nodes_to_skip=nodes_to_skip or set(), execution_context=execution_context, ) @@ -1145,6 +1147,8 @@ class GraphExecutionEntry(BaseModel): graph_id: str graph_version: int nodes_input_masks: Optional[NodesInputMasks] = None + nodes_to_skip: set[str] = Field(default_factory=set) + """Node IDs that should be skipped due to optional credentials not being configured.""" execution_context: ExecutionContext = Field(default_factory=ExecutionContext) diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 
0757a86f4a..e9be80892c 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -94,6 +94,15 @@ class Node(BaseDbModel): input_links: list[Link] = [] output_links: list[Link] = [] + @property + def credentials_optional(self) -> bool: + """ + Whether credentials are optional for this node. + When True and credentials are not configured, the node will be skipped + during execution rather than causing a validation error. + """ + return self.metadata.get("credentials_optional", False) + @property def block(self) -> AnyBlockSchema | "_UnknownBlockBase": """Get the block for this node. Returns UnknownBlock if block is deleted/missing.""" @@ -235,7 +244,10 @@ class BaseGraph(BaseDbModel): return any( node.block_id for node in self.nodes - if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP + if ( + node.block.block_type == BlockType.HUMAN_IN_THE_LOOP + or node.block.requires_human_review + ) ) @property @@ -326,7 +338,35 @@ class Graph(BaseGraph): @computed_field @property def credentials_input_schema(self) -> dict[str, Any]: - return self._credentials_input_schema.jsonschema() + schema = self._credentials_input_schema.jsonschema() + + # Determine which credential fields are required based on credentials_optional metadata + graph_credentials_inputs = self.aggregate_credentials_inputs() + required_fields = [] + + # Build a map of node_id -> node for quick lookup + all_nodes = {node.id: node for node in self.nodes} + for sub_graph in self.sub_graphs: + for node in sub_graph.nodes: + all_nodes[node.id] = node + + for field_key, ( + _field_info, + node_field_pairs, + ) in graph_credentials_inputs.items(): + # A field is required if ANY node using it has credentials_optional=False + is_required = False + for node_id, _field_name in node_field_pairs: + node = all_nodes.get(node_id) + if node and not node.credentials_optional: + is_required = True + break + + if is_required: + required_fields.append(field_key) + + schema["required"] = required_fields + return schema @property def _credentials_input_schema(self) -> type[BlockSchema]: diff --git a/autogpt_platform/backend/backend/data/graph_test.py b/autogpt_platform/backend/backend/data/graph_test.py index 044d75e0ca..eea7277eb9 100644 --- a/autogpt_platform/backend/backend/data/graph_test.py +++ b/autogpt_platform/backend/backend/data/graph_test.py @@ -396,3 +396,58 @@ async def test_access_store_listing_graph(server: SpinTestServer): created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b" ) assert got_graph is not None + + +# ============================================================================ +# Tests for Optional Credentials Feature +# ============================================================================ + + +def test_node_credentials_optional_default(): + """Test that credentials_optional defaults to False when not set in metadata.""" + node = Node( + id="test_node", + block_id=StoreValueBlock().id, + input_default={}, + metadata={}, + ) + assert node.credentials_optional is False + + +def test_node_credentials_optional_true(): + """Test that credentials_optional returns True when explicitly set.""" + node = Node( + id="test_node", + block_id=StoreValueBlock().id, + input_default={}, + metadata={"credentials_optional": True}, + ) + assert node.credentials_optional is True + + +def test_node_credentials_optional_false(): + """Test that credentials_optional returns False when explicitly set to False.""" + node = Node( + id="test_node", + 
block_id=StoreValueBlock().id, + input_default={}, + metadata={"credentials_optional": False}, + ) + assert node.credentials_optional is False + + +def test_node_credentials_optional_with_other_metadata(): + """Test that credentials_optional works correctly with other metadata present.""" + node = Node( + id="test_node", + block_id=StoreValueBlock().id, + input_default={}, + metadata={ + "position": {"x": 100, "y": 200}, + "customized_name": "My Custom Node", + "credentials_optional": True, + }, + ) + assert node.credentials_optional is True + assert node.metadata["position"] == {"x": 100, "y": 200} + assert node.metadata["customized_name"] == "My Custom Node" diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 75459c5a2a..39d4f984eb 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -178,6 +178,7 @@ async def execute_node( execution_processor: "ExecutionProcessor", execution_stats: NodeExecutionStats | None = None, nodes_input_masks: Optional[NodesInputMasks] = None, + nodes_to_skip: Optional[set[str]] = None, ) -> BlockOutput: """ Execute a node in the graph. This will trigger a block execution on a node, @@ -245,6 +246,7 @@ async def execute_node( "user_id": user_id, "execution_context": execution_context, "execution_processor": execution_processor, + "nodes_to_skip": nodes_to_skip or set(), } # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent @@ -542,6 +544,7 @@ class ExecutionProcessor: node_exec_progress: NodeExecutionProgress, nodes_input_masks: Optional[NodesInputMasks], graph_stats_pair: tuple[GraphExecutionStats, threading.Lock], + nodes_to_skip: Optional[set[str]] = None, ) -> NodeExecutionStats: log_metadata = LogMetadata( logger=_logger, @@ -564,6 +567,7 @@ class ExecutionProcessor: db_client=db_client, log_metadata=log_metadata, nodes_input_masks=nodes_input_masks, + nodes_to_skip=nodes_to_skip, ) if isinstance(status, BaseException): raise status @@ -609,6 +613,7 @@ class ExecutionProcessor: db_client: "DatabaseManagerAsyncClient", log_metadata: LogMetadata, nodes_input_masks: Optional[NodesInputMasks] = None, + nodes_to_skip: Optional[set[str]] = None, ) -> ExecutionStatus: status = ExecutionStatus.RUNNING @@ -645,6 +650,7 @@ class ExecutionProcessor: execution_processor=self, execution_stats=stats, nodes_input_masks=nodes_input_masks, + nodes_to_skip=nodes_to_skip, ): await persist_output(output_name, output_data) @@ -956,6 +962,21 @@ class ExecutionProcessor: queued_node_exec = execution_queue.get() + # Check if this node should be skipped due to optional credentials + if queued_node_exec.node_id in graph_exec.nodes_to_skip: + log_metadata.info( + f"Skipping node execution {queued_node_exec.node_exec_id} " + f"for node {queued_node_exec.node_id} - optional credentials not configured" + ) + # Mark the node as completed without executing + # No outputs will be produced, so downstream nodes won't trigger + update_node_execution_status( + db_client=db_client, + exec_id=queued_node_exec.node_exec_id, + status=ExecutionStatus.COMPLETED, + ) + continue + log_metadata.debug( f"Dispatching node execution {queued_node_exec.node_exec_id} " f"for node {queued_node_exec.node_id}", @@ -1016,6 +1037,7 @@ class ExecutionProcessor: execution_stats, execution_stats_lock, ), + nodes_to_skip=graph_exec.nodes_to_skip, ), self.node_execution_loop, ) diff --git a/autogpt_platform/backend/backend/executor/utils.py 
b/autogpt_platform/backend/backend/executor/utils.py index bcd3dcf3b6..1fb2b9404f 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -239,14 +239,19 @@ async def _validate_node_input_credentials( graph: GraphModel, user_id: str, nodes_input_masks: Optional[NodesInputMasks] = None, -) -> dict[str, dict[str, str]]: +) -> tuple[dict[str, dict[str, str]], set[str]]: """ - Checks all credentials for all nodes of the graph and returns structured errors. + Checks all credentials for all nodes of the graph and returns structured errors + and a set of nodes that should be skipped due to missing optional credentials. Returns: - dict[node_id, dict[field_name, error_message]]: Credential validation errors per node + tuple[ + dict[node_id, dict[field_name, error_message]]: Credential validation errors per node, + set[node_id]: Nodes that should be skipped (optional credentials not configured) + ] """ credential_errors: dict[str, dict[str, str]] = defaultdict(dict) + nodes_to_skip: set[str] = set() for node in graph.nodes: block = node.block @@ -256,27 +261,46 @@ async def _validate_node_input_credentials( if not credentials_fields: continue + # Track if any credential field is missing for this node + has_missing_credentials = False + for field_name, credentials_meta_type in credentials_fields.items(): try: + # Check nodes_input_masks first, then input_default + field_value = None if ( nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)) and field_name in node_input_mask ): - credentials_meta = credentials_meta_type.model_validate( - node_input_mask[field_name] - ) + field_value = node_input_mask[field_name] elif field_name in node.input_default: - credentials_meta = credentials_meta_type.model_validate( - node.input_default[field_name] - ) - else: - # Missing credentials - credential_errors[node.id][ - field_name - ] = "These credentials are required" - continue + # For optional credentials, don't use input_default - treat as missing + # This prevents stale credential IDs from failing validation + if node.credentials_optional: + field_value = None + else: + field_value = node.input_default[field_name] + + # Check if credentials are missing (None, empty, or not present) + if field_value is None or ( + isinstance(field_value, dict) and not field_value.get("id") + ): + has_missing_credentials = True + # If node has credentials_optional flag, mark for skipping instead of error + if node.credentials_optional: + continue # Don't add error, will be marked for skip after loop + else: + credential_errors[node.id][ + field_name + ] = "These credentials are required" + continue + + credentials_meta = credentials_meta_type.model_validate(field_value) + except ValidationError as e: + # Validation error means credentials were provided but invalid + # This should always be an error, even if optional credential_errors[node.id][field_name] = f"Invalid credentials: {e}" continue @@ -287,6 +311,7 @@ async def _validate_node_input_credentials( ) except Exception as e: # Handle any errors fetching credentials + # If credentials were explicitly configured but unavailable, it's an error credential_errors[node.id][ field_name ] = f"Credentials not available: {e}" @@ -313,7 +338,19 @@ async def _validate_node_input_credentials( ] = "Invalid credentials: type/provider mismatch" continue - return credential_errors + # If node has optional credentials and any are missing, mark for skipping + # But only if there are no other errors for
this node + if ( + has_missing_credentials + and node.credentials_optional + and node.id not in credential_errors + ): + nodes_to_skip.add(node.id) + logger.info( + f"Node #{node.id} will be skipped: optional credentials not configured" + ) + + return credential_errors, nodes_to_skip def make_node_credentials_input_map( @@ -355,21 +392,25 @@ async def validate_graph_with_credentials( graph: GraphModel, user_id: str, nodes_input_masks: Optional[NodesInputMasks] = None, -) -> Mapping[str, Mapping[str, str]]: +) -> tuple[Mapping[str, Mapping[str, str]], set[str]]: """ - Validate graph including credentials and return structured errors per node. + Validate graph including credentials and return structured errors per node, + along with a set of nodes that should be skipped due to missing optional credentials. Returns: - dict[node_id, dict[field_name, error_message]]: Validation errors per node + tuple[ + dict[node_id, dict[field_name, error_message]]: Validation errors per node, + set[node_id]: Nodes that should be skipped (optional credentials not configured) + ] """ # Get input validation errors node_input_errors = GraphModel.validate_graph_get_errors( graph, for_run=True, nodes_input_masks=nodes_input_masks ) - # Get credential input/availability/validation errors - node_credential_input_errors = await _validate_node_input_credentials( - graph, user_id, nodes_input_masks + # Get credential input/availability/validation errors and nodes to skip + node_credential_input_errors, nodes_to_skip = ( + await _validate_node_input_credentials(graph, user_id, nodes_input_masks) ) # Merge credential errors with structural errors @@ -378,7 +419,7 @@ node_input_errors[node_id] = {} node_input_errors[node_id].update(field_errors) - return node_input_errors + return node_input_errors, nodes_to_skip async def _construct_starting_node_execution_input( @@ -386,7 +427,7 @@ user_id: str, graph_inputs: BlockInput, nodes_input_masks: Optional[NodesInputMasks] = None, -) -> list[tuple[str, BlockInput]]: +) -> tuple[list[tuple[str, BlockInput]], set[str]]: """ Validates and prepares the input data for executing a graph. This function checks the graph for starting nodes, validates the input data @@ -400,11 +441,14 @@ node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]` Returns: - list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and - the corresponding input data for that node. + tuple[ + list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID + and the corresponding input data for that node. + set[str]: Node IDs that should be skipped (optional credentials not configured) + ] """ # Use new validation function that includes credentials - validation_errors = await validate_graph_with_credentials( + validation_errors, nodes_to_skip = await validate_graph_with_credentials( graph, user_id, nodes_input_masks ) n_error_nodes = len(validation_errors) @@ -445,7 +489,7 @@ "No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
) - return nodes_input + return nodes_input, nodes_to_skip async def validate_and_construct_node_execution_input( @@ -456,7 +500,7 @@ async def validate_and_construct_node_execution_input( graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None, nodes_input_masks: Optional[NodesInputMasks] = None, is_sub_graph: bool = False, -) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]: +) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]: """ Public wrapper that handles graph fetching, credential mapping, and validation+construction. This centralizes the logic used by both scheduler validation and actual execution. @@ -473,6 +517,7 @@ async def validate_and_construct_node_execution_input( GraphModel: Full graph object for the given `graph_id`. list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs. dict[str, BlockInput]: Node input masks including all passed-in credentials. + set[str]: Node IDs that should be skipped (optional credentials not configured). Raises: NotFoundError: If the graph is not found. @@ -514,14 +559,16 @@ async def validate_and_construct_node_execution_input( nodes_input_masks or {}, ) - starting_nodes_input = await _construct_starting_node_execution_input( - graph=graph, - user_id=user_id, - graph_inputs=graph_inputs, - nodes_input_masks=nodes_input_masks, + starting_nodes_input, nodes_to_skip = ( + await _construct_starting_node_execution_input( + graph=graph, + user_id=user_id, + graph_inputs=graph_inputs, + nodes_input_masks=nodes_input_masks, + ) ) - return graph, starting_nodes_input, nodes_input_masks + return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip def _merge_nodes_input_masks( @@ -779,6 +826,9 @@ async def add_graph_execution( # Use existing execution's compiled input masks compiled_nodes_input_masks = graph_exec.nodes_input_masks or {} + # For resumed executions, nodes_to_skip was already determined at creation time + # TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes + nodes_to_skip: set[str] = set() logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}") else: @@ -787,7 +837,7 @@ async def add_graph_execution( ) # Create new execution - graph, starting_nodes_input, compiled_nodes_input_masks = ( + graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = ( await validate_and_construct_node_execution_input( graph_id=graph_id, user_id=user_id, @@ -836,6 +886,7 @@ async def add_graph_execution( try: graph_exec_entry = graph_exec.to_graph_execution_entry( compiled_nodes_input_masks=compiled_nodes_input_masks, + nodes_to_skip=nodes_to_skip, execution_context=execution_context, ) logger.info(f"Publishing execution {graph_exec.id} to execution queue") diff --git a/autogpt_platform/backend/backend/executor/utils_test.py b/autogpt_platform/backend/backend/executor/utils_test.py index 8854214e14..0e652f9627 100644 --- a/autogpt_platform/backend/backend/executor/utils_test.py +++ b/autogpt_platform/backend/backend/executor/utils_test.py @@ -367,10 +367,13 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture): ) # Setup mock returns + # The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip) + nodes_to_skip: set[str] = set() mock_validate.return_value = ( mock_graph, starting_nodes_input, compiled_nodes_input_masks, + nodes_to_skip, ) mock_prisma.is_connected.return_value = True mock_edb.create_graph_execution = 
mocker.AsyncMock(return_value=mock_graph_exec) @@ -456,3 +459,212 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture): # Both executions should succeed (though they create different objects) assert result1 == mock_graph_exec assert result2 == mock_graph_exec_2 + + +# ============================================================================ +# Tests for Optional Credentials Feature +# ============================================================================ + + +@pytest.mark.asyncio +async def test_validate_node_input_credentials_returns_nodes_to_skip( + mocker: MockerFixture, +): + """ + Test that _validate_node_input_credentials returns nodes_to_skip set + for nodes with credentials_optional=True and missing credentials. + """ + from backend.executor.utils import _validate_node_input_credentials + + # Create a mock node with credentials_optional=True + mock_node = mocker.MagicMock() + mock_node.id = "node-with-optional-creds" + mock_node.credentials_optional = True + mock_node.input_default = {} # No credentials configured + + # Create a mock block with credentials field + mock_block = mocker.MagicMock() + mock_credentials_field_type = mocker.MagicMock() + mock_block.input_schema.get_credentials_fields.return_value = { + "credentials": mock_credentials_field_type + } + mock_node.block = mock_block + + # Create mock graph + mock_graph = mocker.MagicMock() + mock_graph.nodes = [mock_node] + + # Call the function + errors, nodes_to_skip = await _validate_node_input_credentials( + graph=mock_graph, + user_id="test-user-id", + nodes_input_masks=None, + ) + + # Node should be in nodes_to_skip, not in errors + assert mock_node.id in nodes_to_skip + assert mock_node.id not in errors + + +@pytest.mark.asyncio +async def test_validate_node_input_credentials_required_missing_creds_error( + mocker: MockerFixture, +): + """ + Test that _validate_node_input_credentials returns errors + for nodes with credentials_optional=False and missing credentials. + """ + from backend.executor.utils import _validate_node_input_credentials + + # Create a mock node with credentials_optional=False (required) + mock_node = mocker.MagicMock() + mock_node.id = "node-with-required-creds" + mock_node.credentials_optional = False + mock_node.input_default = {} # No credentials configured + + # Create a mock block with credentials field + mock_block = mocker.MagicMock() + mock_credentials_field_type = mocker.MagicMock() + mock_block.input_schema.get_credentials_fields.return_value = { + "credentials": mock_credentials_field_type + } + mock_node.block = mock_block + + # Create mock graph + mock_graph = mocker.MagicMock() + mock_graph.nodes = [mock_node] + + # Call the function + errors, nodes_to_skip = await _validate_node_input_credentials( + graph=mock_graph, + user_id="test-user-id", + nodes_input_masks=None, + ) + + # Node should be in errors, not in nodes_to_skip + assert mock_node.id in errors + assert "credentials" in errors[mock_node.id] + assert "required" in errors[mock_node.id]["credentials"].lower() + assert mock_node.id not in nodes_to_skip + + +@pytest.mark.asyncio +async def test_validate_graph_with_credentials_returns_nodes_to_skip( + mocker: MockerFixture, +): + """ + Test that validate_graph_with_credentials returns nodes_to_skip set + from _validate_node_input_credentials. 
+ """ + from backend.executor.utils import validate_graph_with_credentials + + # Mock _validate_node_input_credentials to return specific values + mock_validate = mocker.patch( + "backend.executor.utils._validate_node_input_credentials" + ) + expected_errors = {"node1": {"field": "error"}} + expected_nodes_to_skip = {"node2", "node3"} + mock_validate.return_value = (expected_errors, expected_nodes_to_skip) + + # Mock GraphModel with validate_graph_get_errors method + mock_graph = mocker.MagicMock() + mock_graph.validate_graph_get_errors.return_value = {} + + # Call the function + errors, nodes_to_skip = await validate_graph_with_credentials( + graph=mock_graph, + user_id="test-user-id", + nodes_input_masks=None, + ) + + # Verify nodes_to_skip is passed through + assert nodes_to_skip == expected_nodes_to_skip + assert "node1" in errors + + +@pytest.mark.asyncio +async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture): + """ + Test that add_graph_execution properly passes nodes_to_skip + to the graph execution entry. + """ + from backend.data.execution import GraphExecutionWithNodes + from backend.executor.utils import add_graph_execution + + # Mock data + graph_id = "test-graph-id" + user_id = "test-user-id" + inputs = {"test_input": "test_value"} + graph_version = 1 + + # Mock the graph object + mock_graph = mocker.MagicMock() + mock_graph.version = graph_version + + # Starting nodes and masks + starting_nodes_input = [("node1", {"input1": "value1"})] + compiled_nodes_input_masks = {} + nodes_to_skip = {"skipped-node-1", "skipped-node-2"} + + # Mock the graph execution object + mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes) + mock_graph_exec.id = "execution-id-123" + mock_graph_exec.node_executions = [] + + # Track what's passed to to_graph_execution_entry + captured_kwargs = {} + + def capture_to_entry(**kwargs): + captured_kwargs.update(kwargs) + return mocker.MagicMock() + + mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry + + # Setup mocks + mock_validate = mocker.patch( + "backend.executor.utils.validate_and_construct_node_execution_input" + ) + mock_edb = mocker.patch("backend.executor.utils.execution_db") + mock_prisma = mocker.patch("backend.executor.utils.prisma") + mock_udb = mocker.patch("backend.executor.utils.user_db") + mock_gdb = mocker.patch("backend.executor.utils.graph_db") + mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue") + mock_get_event_bus = mocker.patch( + "backend.executor.utils.get_async_execution_event_bus" + ) + + # Setup returns - include nodes_to_skip in the tuple + mock_validate.return_value = ( + mock_graph, + starting_nodes_input, + compiled_nodes_input_masks, + nodes_to_skip, # This should be passed through + ) + mock_prisma.is_connected.return_value = True + mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec) + mock_edb.update_graph_execution_stats = mocker.AsyncMock( + return_value=mock_graph_exec + ) + mock_edb.update_node_execution_status_batch = mocker.AsyncMock() + + mock_user = mocker.MagicMock() + mock_user.timezone = "UTC" + mock_settings = mocker.MagicMock() + mock_settings.human_in_the_loop_safe_mode = True + + mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user) + mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings) + mock_get_queue.return_value = mocker.AsyncMock() + mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock()) + + # Call the function + await 
add_graph_execution( + graph_id=graph_id, + user_id=user_id, + inputs=inputs, + graph_version=graph_version, + ) + + # Verify nodes_to_skip was passed to to_graph_execution_entry + assert "nodes_to_skip" in captured_kwargs + assert captured_kwargs["nodes_to_skip"] == nodes_to_skip diff --git a/autogpt_platform/backend/backend/integrations/oauth/__init__.py b/autogpt_platform/backend/backend/integrations/oauth/__init__.py index 137b9eadfd..1ea7a23ee7 100644 --- a/autogpt_platform/backend/backend/integrations/oauth/__init__.py +++ b/autogpt_platform/backend/backend/integrations/oauth/__init__.py @@ -8,6 +8,7 @@ from .discord import DiscordOAuthHandler from .github import GitHubOAuthHandler from .google import GoogleOAuthHandler from .notion import NotionOAuthHandler +from .reddit import RedditOAuthHandler from .twitter import TwitterOAuthHandler if TYPE_CHECKING: @@ -20,6 +21,7 @@ _ORIGINAL_HANDLERS = [ GitHubOAuthHandler, GoogleOAuthHandler, NotionOAuthHandler, + RedditOAuthHandler, TwitterOAuthHandler, TodoistOAuthHandler, ] diff --git a/autogpt_platform/backend/backend/integrations/oauth/reddit.py b/autogpt_platform/backend/backend/integrations/oauth/reddit.py new file mode 100644 index 0000000000..a69e1e62c7 --- /dev/null +++ b/autogpt_platform/backend/backend/integrations/oauth/reddit.py @@ -0,0 +1,208 @@ +import time +import urllib.parse +from typing import ClassVar, Optional + +from pydantic import SecretStr + +from backend.data.model import OAuth2Credentials +from backend.integrations.oauth.base import BaseOAuthHandler +from backend.integrations.providers import ProviderName +from backend.util.request import Requests +from backend.util.settings import Settings + +settings = Settings() + + +class RedditOAuthHandler(BaseOAuthHandler): + """ + Reddit OAuth 2.0 handler. 
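+
+    Implements the authorization-code grant with HTTP Basic client
+    authentication. A login URL from get_login_url() looks like this
+    (illustrative; placeholder values):
+
+        https://www.reddit.com/api/v1/authorize?response_type=code
+            &client_id=<client_id>&redirect_uri=<redirect_uri>
+            &scope=identity+read&state=<state>&duration=permanent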
+ + Based on the documentation at: + - https://github.com/reddit-archive/reddit/wiki/OAuth2 + + Notes: + - Reddit requires `duration=permanent` to get refresh tokens + - Access tokens expire after 1 hour (3600 seconds) + - Reddit requires HTTP Basic Auth for token requests + - Reddit requires a unique User-Agent header + """ + + PROVIDER_NAME = ProviderName.REDDIT + DEFAULT_SCOPES: ClassVar[list[str]] = [ + "identity", # Get username, verify auth + "read", # Access posts and comments + "submit", # Submit new posts and comments + "edit", # Edit own posts and comments + "history", # Access user's post history + "privatemessages", # Access inbox and send private messages + "flair", # Access and set flair on posts/subreddits + ] + + AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize" + TOKEN_URL = "https://www.reddit.com/api/v1/access_token" + USERNAME_URL = "https://oauth.reddit.com/api/v1/me" + REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token" + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + + def get_login_url( + self, scopes: list[str], state: str, code_challenge: Optional[str] + ) -> str: + """Generate Reddit OAuth 2.0 authorization URL""" + scopes = self.handle_default_scopes(scopes) + + params = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(scopes), + "state": state, + "duration": "permanent", # Required for refresh tokens + } + + return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_tokens( + self, code: str, scopes: list[str], code_verifier: Optional[str] + ) -> OAuth2Credentials: + """Exchange authorization code for access tokens""" + scopes = self.handle_default_scopes(scopes) + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": settings.config.reddit_user_agent, + } + + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.redirect_uri, + } + + # Reddit requires HTTP Basic Auth for token requests + auth = (self.client_id, self.client_secret) + + response = await Requests().post( + self.TOKEN_URL, headers=headers, data=data, auth=auth + ) + + if not response.ok: + error_text = response.text() + raise ValueError( + f"Reddit token exchange failed: {response.status} - {error_text}" + ) + + tokens = response.json() + + if "error" in tokens: + raise ValueError(f"Reddit OAuth error: {tokens.get('error')}") + + username = await self._get_username(tokens["access_token"]) + + return OAuth2Credentials( + provider=self.PROVIDER_NAME, + title=None, + username=username, + access_token=tokens["access_token"], + refresh_token=tokens.get("refresh_token"), + access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600), + refresh_token_expires_at=None, # Reddit refresh tokens don't expire + scopes=scopes, + ) + + async def _get_username(self, access_token: str) -> str: + """Get the username from the access token""" + headers = { + "Authorization": f"Bearer {access_token}", + "User-Agent": settings.config.reddit_user_agent, + } + + response = await Requests().get(self.USERNAME_URL, headers=headers) + + if not response.ok: + raise ValueError(f"Failed to get Reddit username: {response.status}") + + data = response.json() + return data.get("name", "unknown") + + async def _refresh_tokens( + self, credentials: OAuth2Credentials + ) -> OAuth2Credentials: + """Refresh access 
tokens using refresh token""" + if not credentials.refresh_token: + raise ValueError("No refresh token available") + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": settings.config.reddit_user_agent, + } + + data = { + "grant_type": "refresh_token", + "refresh_token": credentials.refresh_token.get_secret_value(), + } + + auth = (self.client_id, self.client_secret) + + response = await Requests().post( + self.TOKEN_URL, headers=headers, data=data, auth=auth + ) + + if not response.ok: + error_text = response.text() + raise ValueError( + f"Reddit token refresh failed: {response.status} - {error_text}" + ) + + tokens = response.json() + + if "error" in tokens: + raise ValueError(f"Reddit OAuth error: {tokens.get('error')}") + + username = await self._get_username(tokens["access_token"]) + + # Reddit may or may not return a new refresh token + new_refresh_token = tokens.get("refresh_token") + if new_refresh_token: + refresh_token: SecretStr | None = SecretStr(new_refresh_token) + elif credentials.refresh_token: + # Keep the existing refresh token + refresh_token = credentials.refresh_token + else: + refresh_token = None + + return OAuth2Credentials( + id=credentials.id, + provider=self.PROVIDER_NAME, + title=credentials.title, + username=username, + access_token=tokens["access_token"], + refresh_token=refresh_token, + access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600), + refresh_token_expires_at=None, + scopes=credentials.scopes, + ) + + async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool: + """Revoke the access token""" + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": settings.config.reddit_user_agent, + } + + data = { + "token": credentials.access_token.get_secret_value(), + "token_type_hint": "access_token", + } + + auth = (self.client_id, self.client_secret) + + response = await Requests().post( + self.REVOKE_URL, headers=headers, data=data, auth=auth + ) + + # Reddit returns 204 No Content on successful revocation + return response.ok diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index 0f17b1215c..7a51200eaf 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -264,7 +264,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): ) reddit_user_agent: str = Field( - default="AutoGPT:1.0 (by /u/autogpt)", + default="web:AutoGPT:v0.6.0 (by /u/autogpt)", description="The user agent for the Reddit API", ) diff --git a/autogpt_platform/backend/gen_prisma_types_stub.py b/autogpt_platform/backend/gen_prisma_types_stub.py new file mode 100644 index 0000000000..3f6073b2ff --- /dev/null +++ b/autogpt_platform/backend/gen_prisma_types_stub.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +Generate a lightweight stub for prisma/types.py that collapses all exported +symbols to Any. This prevents Pyright from spending time/budget on Prisma's +query DSL types while keeping runtime behavior unchanged. + +Usage: + poetry run gen-prisma-stub + +This script automatically finds the prisma package location and generates +the types.pyi stub file in the same directory as types.py. 
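+
+The emitted stub keeps safe aliases and collapses everything else, e.g.
+(illustrative; actual names depend on the generated client):
+
+    SortOrder = Literal['asc', 'desc']    # safe Literal alias, kept verbatim
+    AgentGraphWhereInput = _PrismaDict    # complex query type, collapsed
+    def some_helper(*args: Any, **kwargs: Any) -> Any: ...  # hypothetical name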
+""" + +from __future__ import annotations + +import ast +import importlib.util +import sys +from pathlib import Path +from typing import Iterable, Set + + +def _iter_assigned_names(target: ast.expr) -> Iterable[str]: + """Extract names from assignment targets (handles tuple unpacking).""" + if isinstance(target, ast.Name): + yield target.id + elif isinstance(target, (ast.Tuple, ast.List)): + for elt in target.elts: + yield from _iter_assigned_names(elt) + + +def _is_private(name: str) -> bool: + """Check if a name is private (starts with _ but not __).""" + return name.startswith("_") and not name.startswith("__") + + +def _is_safe_type_alias(node: ast.Assign) -> bool: + """Check if an assignment is a safe type alias that shouldn't be stubbed. + + Safe types are: + - Literal types (don't cause type budget issues) + - Simple type references (SortMode, SortOrder, etc.) + - TypeVar definitions + """ + if not node.value: + return False + + # Check if it's a Subscript (like Literal[...], Union[...], TypeVar[...]) + if isinstance(node.value, ast.Subscript): + # Get the base type name + if isinstance(node.value.value, ast.Name): + base_name = node.value.value.id + # Literal types are safe + if base_name == "Literal": + return True + # TypeVar is safe + if base_name == "TypeVar": + return True + elif isinstance(node.value.value, ast.Attribute): + # Handle typing_extensions.Literal etc. + if node.value.value.attr == "Literal": + return True + + # Check if it's a simple Name reference (like SortMode = _types.SortMode) + if isinstance(node.value, ast.Attribute): + return True + + # Check if it's a Call (like TypeVar(...)) + if isinstance(node.value, ast.Call): + if isinstance(node.value.func, ast.Name): + if node.value.func.id == "TypeVar": + return True + + return False + + +def collect_top_level_symbols( + tree: ast.Module, source_lines: list[str] +) -> tuple[Set[str], Set[str], list[str], Set[str]]: + """Collect all top-level symbols from an AST module. 
+ + Returns: + Tuple of (class_names, function_names, safe_variable_sources, unsafe_variable_names) + safe_variable_sources contains the actual source code lines for safe variables + """ + classes: Set[str] = set() + functions: Set[str] = set() + safe_variable_sources: list[str] = [] + unsafe_variables: Set[str] = set() + + for node in tree.body: + if isinstance(node, ast.ClassDef): + if not _is_private(node.name): + classes.add(node.name) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + if not _is_private(node.name): + functions.add(node.name) + elif isinstance(node, ast.Assign): + is_safe = _is_safe_type_alias(node) + names = [] + for t in node.targets: + for n in _iter_assigned_names(t): + if not _is_private(n): + names.append(n) + if names: + if is_safe: + # Extract the source code for this assignment + start_line = node.lineno - 1 # 0-indexed + end_line = node.end_lineno if node.end_lineno else node.lineno + source = "\n".join(source_lines[start_line:end_line]) + safe_variable_sources.append(source) + else: + unsafe_variables.update(names) + elif isinstance(node, ast.AnnAssign) and node.target: + # Annotated assignments are always stubbed + for n in _iter_assigned_names(node.target): + if not _is_private(n): + unsafe_variables.add(n) + + return classes, functions, safe_variable_sources, unsafe_variables + + +def find_prisma_types_path() -> Path: + """Find the prisma types.py file in the installed package.""" + spec = importlib.util.find_spec("prisma") + if spec is None or spec.origin is None: + raise RuntimeError("Could not find prisma package. Is it installed?") + + prisma_dir = Path(spec.origin).parent + types_path = prisma_dir / "types.py" + + if not types_path.exists(): + raise RuntimeError(f"prisma/types.py not found at {types_path}") + + return types_path + + +def generate_stub(src_path: Path, stub_path: Path) -> int: + """Generate the .pyi stub file from the source types.py.""" + code = src_path.read_text(encoding="utf-8", errors="ignore") + source_lines = code.splitlines() + tree = ast.parse(code, filename=str(src_path)) + classes, functions, safe_variable_sources, unsafe_variables = ( + collect_top_level_symbols(tree, source_lines) + ) + + header = """\ +# -*- coding: utf-8 -*- +# Auto-generated stub file - DO NOT EDIT +# Generated by gen_prisma_types_stub.py +# +# This stub intentionally collapses complex Prisma query DSL types to Any. +# Prisma's generated types can explode Pyright's type inference budgets +# on large schemas. We collapse them to Any so the rest of the codebase +# can remain strongly typed while keeping runtime behavior unchanged. +# +# Safe types (Literal, TypeVar, simple references) are preserved from the +# original types.py to maintain proper type checking where possible. + +from __future__ import annotations +from typing import Any +from typing_extensions import Literal + +# Re-export commonly used typing constructs that may be imported from this module +from typing import TYPE_CHECKING, TypeVar, Generic, Union, Optional, List, Dict + +# Base type alias for stubbed Prisma types - allows any dict structure +_PrismaDict = dict[str, Any] + +""" + + lines = [header] + + # Include safe variable definitions (Literal types, TypeVars, etc.) + lines.append("# Safe type definitions preserved from original types.py") + for source in safe_variable_sources: + lines.append(source) + lines.append("") + + # Stub all classes and unsafe variables uniformly as dict[str, Any] aliases + # This allows: + # 1. Use in type annotations: x: SomeType + # 2. 
Constructor calls: SomeType(...) + # 3. Dict literal assignments: x: SomeType = {...} + lines.append( + "# Stubbed types (collapsed to dict[str, Any] to prevent type budget exhaustion)" + ) + all_stubbed = sorted(classes | unsafe_variables) + for name in all_stubbed: + lines.append(f"{name} = _PrismaDict") + + lines.append("") + + # Stub functions + for name in sorted(functions): + lines.append(f"def {name}(*args: Any, **kwargs: Any) -> Any: ...") + + lines.append("") + + stub_path.write_text("\n".join(lines), encoding="utf-8") + return ( + len(classes) + + len(functions) + + len(safe_variable_sources) + + len(unsafe_variables) + ) + + +def main() -> None: + """Main entry point.""" + try: + types_path = find_prisma_types_path() + stub_path = types_path.with_suffix(".pyi") + + print(f"Found prisma types.py at: {types_path}") + print(f"Generating stub at: {stub_path}") + + num_symbols = generate_stub(types_path, stub_path) + print(f"Generated {stub_path.name} with {num_symbols} Any-typed symbols") + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/autogpt_platform/backend/linter.py b/autogpt_platform/backend/linter.py index a86e6761f7..599aae4580 100644 --- a/autogpt_platform/backend/linter.py +++ b/autogpt_platform/backend/linter.py @@ -25,6 +25,9 @@ def run(*command: str) -> None: def lint(): + # Generate Prisma types stub before running pyright to prevent type budget exhaustion + run("gen-prisma-stub") + lint_step_args: list[list[str]] = [ ["ruff", "check", *TARGET_DIRS, "--exit-zero"], ["ruff", "format", "--diff", "--check", LIBS_DIR], @@ -49,4 +52,6 @@ def format(): run("ruff", "format", LIBS_DIR) run("isort", "--profile", "black", BACKEND_DIR) run("black", BACKEND_DIR) + # Generate Prisma types stub before running pyright to prevent type budget exhaustion + run("gen-prisma-stub") run("pyright", *TARGET_DIRS) diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index e8b8fd0ba5..21bf15e776 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -117,6 +117,7 @@ lint = "linter:lint" test = "run_tests:test" load-store-agents = "test.load_store_agents:run" export-api-schema = "backend.cli.generate_openapi_json:main" +gen-prisma-stub = "gen_prisma_types_stub:main" oauth-tool = "backend.cli.oauth_tool:cli" [tool.isort] @@ -134,6 +135,9 @@ ignore_patterns = [] [tool.pytest.ini_options] asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "session" +# Disable syrupy plugin to avoid conflict with pytest-snapshot +# Both provide --snapshot-update argument causing ArgumentError +addopts = "-p no:syrupy" filterwarnings = [ "ignore:'audioop' is deprecated:DeprecationWarning:discord.player", "ignore:invalid escape sequence:DeprecationWarning:tweepy.api", diff --git a/autogpt_platform/backend/snapshots/grph_single b/autogpt_platform/backend/snapshots/grph_single index 7ba26f6171..7ce8695e6b 100644 --- a/autogpt_platform/backend/snapshots/grph_single +++ b/autogpt_platform/backend/snapshots/grph_single @@ -2,6 +2,7 @@ "created_at": "2025-09-04T13:37:00", "credentials_input_schema": { "properties": {}, + "required": [], "title": "TestGraphCredentialsInputSchema", "type": "object" }, diff --git a/autogpt_platform/backend/snapshots/grphs_all b/autogpt_platform/backend/snapshots/grphs_all index d54df2bc18..f69b45a6de 100644 --- a/autogpt_platform/backend/snapshots/grphs_all +++ b/autogpt_platform/backend/snapshots/grphs_all @@ 
-2,6 +2,7 @@ { "credentials_input_schema": { "properties": {}, + "required": [], "title": "TestGraphCredentialsInputSchema", "type": "object" }, diff --git a/autogpt_platform/backend/snapshots/lib_agts_search b/autogpt_platform/backend/snapshots/lib_agts_search index d1feb7d16d..c8e3cc73a6 100644 --- a/autogpt_platform/backend/snapshots/lib_agts_search +++ b/autogpt_platform/backend/snapshots/lib_agts_search @@ -4,6 +4,7 @@ "id": "test-agent-1", "graph_id": "test-agent-1", "graph_version": 1, + "owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "image_url": null, "creator_name": "Test Creator", "creator_image_url": "", @@ -41,6 +42,7 @@ "id": "test-agent-2", "graph_id": "test-agent-2", "graph_version": 1, + "owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "image_url": null, "creator_name": "Test Creator", "creator_image_url": "", diff --git a/autogpt_platform/backend/snapshots/sub_success b/autogpt_platform/backend/snapshots/sub_success index 13e2ec570d..268d577745 100644 --- a/autogpt_platform/backend/snapshots/sub_success +++ b/autogpt_platform/backend/snapshots/sub_success @@ -1,6 +1,7 @@ { "submissions": [ { + "listing_id": "test-listing-id", "agent_id": "test-agent-id", "agent_version": 1, "name": "Test Agent", diff --git a/autogpt_platform/docker-compose.platform.yml b/autogpt_platform/docker-compose.platform.yml index b2df626029..de6ecfd612 100644 --- a/autogpt_platform/docker-compose.platform.yml +++ b/autogpt_platform/docker-compose.platform.yml @@ -37,7 +37,7 @@ services: context: ../ dockerfile: autogpt_platform/backend/Dockerfile target: migrate - command: ["sh", "-c", "poetry run prisma generate && poetry run prisma migrate deploy"] + command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"] develop: watch: - path: ./ diff --git a/autogpt_platform/frontend/package.json b/autogpt_platform/frontend/package.json index fb8856a30f..f881ebaf5b 100644 --- a/autogpt_platform/frontend/package.json +++ b/autogpt_platform/frontend/package.json @@ -92,7 +92,6 @@ "react-currency-input-field": "4.0.3", "react-day-picker": "9.11.1", "react-dom": "18.3.1", - "react-drag-drop-files": "2.4.0", "react-hook-form": "7.66.0", "react-icons": "5.5.0", "react-markdown": "9.0.3", diff --git a/autogpt_platform/frontend/pnpm-lock.yaml b/autogpt_platform/frontend/pnpm-lock.yaml index 6b3e5e2ffd..4240d0d155 100644 --- a/autogpt_platform/frontend/pnpm-lock.yaml +++ b/autogpt_platform/frontend/pnpm-lock.yaml @@ -200,9 +200,6 @@ importers: react-dom: specifier: 18.3.1 version: 18.3.1(react@18.3.1) - react-drag-drop-files: - specifier: 2.4.0 - version: 2.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) react-hook-form: specifier: 7.66.0 version: 7.66.0(react@18.3.1) @@ -1004,9 +1001,6 @@ packages: '@emotion/memoize@0.8.1': resolution: {integrity: sha512-W2P2c/VRW1/1tLox0mVUalvnWXxavmv/Oum2aPsRcoDJuob75FC3Y8FbpfLwUegRcxINtGUMPq0tFCvYNTBXNA==} - '@emotion/unitless@0.8.1': - resolution: {integrity: sha512-KOEGMu6dmJZtpadb476IsZBclKvILjopjUii3V+7MnXIQCYh8W3NgNcgwo21n9LXZX6EDIKvqfjYxXebDwxKmQ==} - '@epic-web/invariant@1.0.0': resolution: {integrity: sha512-lrTPqgvfFQtR/eY/qkIzp98OGdNJu0m5ji3q/nJI8v3SXkRKEnWiOxMmbvcSoAIzv/cGiuvRy57k4suKQSAdwA==} @@ -3122,9 +3116,6 @@ packages: '@types/statuses@2.0.6': resolution: {integrity: sha512-xMAgYwceFhRA2zY+XbEA7mxYbA093wdiW8Vu6gZPGWy9cmOyU9XesH1tNcEWsKFd5Vzrqx5T3D38PWx1FIIXkA==} - '@types/stylis@4.2.7': - resolution: {integrity: 
sha512-VgDNokpBoKF+wrdvhAAfS55OMQpL6QRglwTwNC3kIgBrzZxA4WsFj+2eLfEA/uMUDzBcEhYmjSbwQakn/i3ajA==} - '@types/tedious@4.0.14': resolution: {integrity: sha512-KHPsfX/FoVbUGbyYvk1q9MMQHLPeRZhRJZdO45Q4YjvFkv4hMNghCWTvy7rdKessBsmtz4euWCWAB6/tVpI1Iw==} @@ -3781,9 +3772,6 @@ packages: resolution: {integrity: sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==} engines: {node: '>= 6'} - camelize@1.0.1: - resolution: {integrity: sha512-dU+Tx2fsypxTgtLoE36npi3UqcjSSMNYfkqgmoEhtZrraP5VWq0K7FkWVTYa8eMPtnU/G2txVsfdCJTn9uzpuQ==} - caniuse-lite@1.0.30001762: resolution: {integrity: sha512-PxZwGNvH7Ak8WX5iXzoK1KPZttBXNPuaOvI2ZYU7NrlM+d9Ov+TUvlLOBNGzVXAntMSMMlJPd+jY6ovrVjSmUw==} @@ -3997,10 +3985,6 @@ packages: resolution: {integrity: sha512-r4ESw/IlusD17lgQi1O20Fa3qNnsckR126TdUuBgAu7GBYSIPvdNyONd3Zrxh0xCwA4+6w/TDArBPsMvhur+KQ==} engines: {node: '>= 0.10'} - css-color-keywords@1.0.0: - resolution: {integrity: sha512-FyyrDHZKEjXDpNJYvVsV960FiqQyXc/LlYmsxl2BcdMb2WPx0OGRVgTg55rPSyLSNMqP52R9r8geSp7apN3Ofg==} - engines: {node: '>=4'} - css-loader@6.11.0: resolution: {integrity: sha512-CTJ+AEQJjq5NzLga5pE39qdiSV56F8ywCIsqNIRF0r7BDgWsN25aazToqAFg7ZrtA/U016xudB3ffgweORxX7g==} engines: {node: '>= 12.13.0'} @@ -4016,9 +4000,6 @@ packages: css-select@4.3.0: resolution: {integrity: sha512-wPpOYtnsVontu2mODhA19JrqWxNsfdatRKd64kmpRbQgh1KtItko5sTnEpPdpSaJszTOhEMlF/RPz28qj4HqhQ==} - css-to-react-native@3.2.0: - resolution: {integrity: sha512-e8RKaLXMOFii+02mOlqwjbD00KSEKqblnpO9e++1aXS1fPQOpS1YoqdVHBqPjHNoxeF2mimzVqawm2KCbEdtHQ==} - css-what@6.2.2: resolution: {integrity: sha512-u/O3vwbptzhMs3L1fQE82ZSLHQQfto5gyZzwteVIEyeaY5Fc7R4dapF/BvRoSYFeqfBk4m0V1Vafq5Pjv25wvA==} engines: {node: '>= 6'} @@ -6131,10 +6112,6 @@ packages: resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==} engines: {node: ^10 || ^12 || >=14} - postcss@8.4.49: - resolution: {integrity: sha512-OCVPnIObs4N29kxTjzLfUryOkvZEq+pf8jTF0lg8E7uETuWHA+v7j3c/xJmiqpX450191LlmZfUKkXxkTry7nA==} - engines: {node: ^10 || ^12 || >=14} - postcss@8.5.6: resolution: {integrity: sha512-3Ybi1tAuwAP9s0r1UQ2J4n5Y0G05bJkpUIO0/bI9MhwmD70S5aTWbXGBwxHrelT+XM1k6dM0pk+SwNkpTRN7Pg==} engines: {node: ^10 || ^12 || >=14} @@ -6306,12 +6283,6 @@ packages: peerDependencies: react: ^18.3.1 - react-drag-drop-files@2.4.0: - resolution: {integrity: sha512-MGPV3HVVnwXEXq3gQfLtSU3jz5j5jrabvGedokpiSEMoONrDHgYl/NpIOlfsqGQ4zBv1bzzv7qbKURZNOX32PA==} - peerDependencies: - react: ^18.0.0 - react-dom: ^18.0.0 - react-hook-form@7.66.0: resolution: {integrity: sha512-xXBqsWGKrY46ZqaHDo+ZUYiMUgi8suYu5kdrS20EG8KiL7VRQitEbNjm+UcrDYrNi1YLyfpmAeGjCZYXLT9YBw==} engines: {node: '>=18.0.0'} @@ -6678,9 +6649,6 @@ packages: engines: {node: '>= 0.10'} hasBin: true - shallowequal@1.1.0: - resolution: {integrity: sha512-y0m1JoUZSlPAjXVtPPW70aZWfIL/dSP7AFkRnniLCrK/8MDKog3TySTBmckD+RObVxH0v4Tox67+F14PdED2oQ==} - sharp@0.34.5: resolution: {integrity: sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} @@ -6894,13 +6862,6 @@ packages: style-to-object@1.0.14: resolution: {integrity: sha512-LIN7rULI0jBscWQYaSswptyderlarFkjQ+t79nzty8tcIAceVomEVlLzH5VP4Cmsv6MtKhs7qaAiwlcp+Mgaxw==} - styled-components@6.2.0: - resolution: {integrity: sha512-ryFCkETE++8jlrBmC+BoGPUN96ld1/Yp0s7t5bcXDobrs4XoXroY1tN+JbFi09hV6a5h3MzbcVi8/BGDP0eCgQ==} - engines: {node: '>= 16'} - peerDependencies: - react: '>= 16.8.0' - react-dom: '>= 16.8.0' - 
styled-jsx@5.1.6: resolution: {integrity: sha512-qSVyDTeMotdvQYoHWLNGwRFJHC+i+ZvdBRYosOFgC+Wg1vx4frN2/RG/NA7SYqqvKNLf39P2LSRA2pu6n0XYZA==} engines: {node: '>= 12.0.0'} @@ -6927,9 +6888,6 @@ packages: babel-plugin-macros: optional: true - stylis@4.3.6: - resolution: {integrity: sha512-yQ3rwFWRfwNUY7H5vpU0wfdkNSnvnJinhF9830Swlaxl03zsOjCfmX0ugac+3LtK0lYSgwL/KXc8oYL3mG4YFQ==} - sucrase@3.35.1: resolution: {integrity: sha512-DhuTmvZWux4H1UOnWMB3sk0sbaCVOoQZjv8u1rDoTV0HTdGem9hkAZtl4JZy8P2z4Bg0nT+YMeOFyVr4zcG5Tw==} engines: {node: '>=16 || 14 >=14.17'} @@ -7096,9 +7054,6 @@ packages: tslib@1.14.1: resolution: {integrity: sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==} - tslib@2.6.2: - resolution: {integrity: sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==} - tslib@2.8.1: resolution: {integrity: sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==} @@ -8335,10 +8290,10 @@ snapshots: '@emotion/is-prop-valid@1.2.2': dependencies: '@emotion/memoize': 0.8.1 + optional: true - '@emotion/memoize@0.8.1': {} - - '@emotion/unitless@0.8.1': {} + '@emotion/memoize@0.8.1': + optional: true '@epic-web/invariant@1.0.0': {} @@ -10734,8 +10689,6 @@ snapshots: '@types/statuses@2.0.6': {} - '@types/stylis@4.2.7': {} - '@types/tedious@4.0.14': dependencies: '@types/node': 24.10.0 @@ -11432,8 +11385,6 @@ snapshots: camelcase-css@2.0.1: {} - camelize@1.0.1: {} - caniuse-lite@1.0.30001762: {} case-sensitive-paths-webpack-plugin@2.4.0: {} @@ -11645,8 +11596,6 @@ snapshots: randombytes: 2.1.0 randomfill: 1.0.4 - css-color-keywords@1.0.0: {} - css-loader@6.11.0(webpack@5.104.1(esbuild@0.25.12)): dependencies: icss-utils: 5.1.0(postcss@8.5.6) @@ -11668,12 +11617,6 @@ snapshots: domutils: 2.8.0 nth-check: 2.1.1 - css-to-react-native@3.2.0: - dependencies: - camelize: 1.0.1 - css-color-keywords: 1.0.0 - postcss-value-parser: 4.2.0 - css-what@6.2.2: {} css.escape@1.5.1: {} @@ -12127,8 +12070,8 @@ snapshots: '@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3) eslint: 8.57.1 eslint-import-resolver-node: 0.3.9 - eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1) - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) + eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1) eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1) eslint-plugin-react: 7.37.5(eslint@8.57.1) eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1) @@ -12147,7 +12090,7 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1): + eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1): dependencies: '@nolyfill/is-core-module': 1.0.39 debug: 4.4.3 @@ -12158,22 +12101,22 @@ snapshots: tinyglobby: 0.2.15 unrs-resolver: 1.11.1 optionalDependencies: - eslint-plugin-import: 
2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1) transitivePeerDependencies: - supports-color - eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1): + eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1): dependencies: debug: 3.2.7 optionalDependencies: '@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3) eslint: 8.57.1 eslint-import-resolver-node: 0.3.9 - eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1) + eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1) transitivePeerDependencies: - supports-color - eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1): + eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1): dependencies: '@rtsao/scc': 1.1.0 array-includes: 3.1.9 @@ -12184,7 +12127,7 @@ snapshots: doctrine: 2.1.0 eslint: 8.57.1 eslint-import-resolver-node: 0.3.9 - eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) + eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1) hasown: 2.0.2 is-core-module: 2.16.1 is-glob: 4.0.3 @@ -14259,12 +14202,6 @@ snapshots: picocolors: 1.1.1 source-map-js: 1.2.1 - postcss@8.4.49: - dependencies: - nanoid: 3.3.11 - picocolors: 1.1.1 - source-map-js: 1.2.1 - postcss@8.5.6: dependencies: nanoid: 3.3.11 @@ -14386,13 +14323,6 @@ snapshots: react: 18.3.1 scheduler: 0.23.2 - react-drag-drop-files@2.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1): - dependencies: - prop-types: 15.8.1 - react: 18.3.1 - react-dom: 18.3.1(react@18.3.1) - styled-components: 6.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) - react-hook-form@7.66.0(react@18.3.1): dependencies: react: 18.3.1 @@ -14886,8 +14816,6 @@ snapshots: safe-buffer: 5.2.1 to-buffer: 1.2.2 - shallowequal@1.1.0: {} - sharp@0.34.5: dependencies: '@img/colour': 1.0.0 @@ -15178,20 +15106,6 @@ snapshots: dependencies: inline-style-parser: 0.2.7 - styled-components@6.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1): - dependencies: - '@emotion/is-prop-valid': 1.2.2 - '@emotion/unitless': 0.8.1 - 
'@types/stylis': 4.2.7 - css-to-react-native: 3.2.0 - csstype: 3.2.3 - postcss: 8.4.49 - react: 18.3.1 - react-dom: 18.3.1(react@18.3.1) - shallowequal: 1.1.0 - stylis: 4.3.6 - tslib: 2.6.2 - styled-jsx@5.1.6(@babel/core@7.28.5)(react@18.3.1): dependencies: client-only: 0.0.1 @@ -15206,8 +15120,6 @@ snapshots: optionalDependencies: '@babel/core': 7.28.5 - stylis@4.3.6: {} - sucrase@3.35.1: dependencies: '@jridgewell/gen-mapping': 0.3.13 @@ -15390,8 +15302,6 @@ snapshots: tslib@1.14.1: {} - tslib@2.6.2: {} - tslib@2.8.1: {} tty-browserify@0.0.1: {} diff --git a/autogpt_platform/frontend/public/integrations/webshare_proxy.png b/autogpt_platform/frontend/public/integrations/webshare_proxy.png new file mode 100644 index 0000000000..2b07ef8415 Binary files /dev/null and b/autogpt_platform/frontend/public/integrations/webshare_proxy.png differ diff --git a/autogpt_platform/frontend/public/integrations/wordpress.png b/autogpt_platform/frontend/public/integrations/wordpress.png new file mode 100644 index 0000000000..b8ba8bd3ff Binary files /dev/null and b/autogpt_platform/frontend/public/integrations/wordpress.png differ diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog.tsx index 431feeaade..bd08aa8ee0 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog.tsx @@ -66,6 +66,7 @@ export const RunInputDialog = ({ formContext={{ showHandles: false, size: "large", + showOptionalToggle: false, }} /> diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts index a71ad0bd07..ddd77bae48 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts @@ -66,7 +66,7 @@ export const useRunInputDialog = ({ if (isCredentialFieldSchema(fieldSchema)) { dynamicUiSchema[fieldName] = { ...dynamicUiSchema[fieldName], - "ui:field": "credentials", + "ui:field": "custom/credential_field", }; } }); @@ -76,12 +76,18 @@ export const useRunInputDialog = ({ }, [credentialsSchema]); const handleManualRun = async () => { + // Filter out incomplete credentials (those without a valid id) + // RJSF auto-populates const values (provider, type) but not id field + const validCredentials = Object.fromEntries( + Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id), + ); + await executeGraph({ graphId: flowID ?? 
"", graphVersion: flowVersion || null, data: { inputs: inputValues, - credentials_inputs: credentialValues, + credentials_inputs: validCredentials, source: "builder", }, }); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx index faaebb6b35..29fd984b1d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx @@ -97,6 +97,9 @@ export const Flow = () => { onConnect={onConnect} onEdgesChange={onEdgesChange} onNodeDragStop={onNodeDragStop} + onNodeContextMenu={(event) => { + event.preventDefault(); + }} maxZoom={2} minZoom={0.1} onDragOver={onDragOver} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx index 3523079b71..99a5b9f0e5 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx @@ -1,24 +1,25 @@ -import React from "react"; -import { Node as XYNode, NodeProps } from "@xyflow/react"; -import { RJSFSchema } from "@rjsf/utils"; -import { BlockUIType } from "../../../types"; -import { StickyNoteBlock } from "./components/StickyNoteBlock"; -import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem"; -import { BlockCost } from "@/app/api/__generated__/models/blockCost"; import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus"; +import { BlockCost } from "@/app/api/__generated__/models/blockCost"; +import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem"; import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult"; -import { NodeContainer } from "./components/NodeContainer"; -import { NodeHeader } from "./components/NodeHeader"; -import { FormCreator } from "../FormCreator"; -import { preprocessInputSchema } from "@/components/renderers/InputRenderer/utils/input-schema-pre-processor"; -import { OutputHandler } from "../OutputHandler"; -import { NodeAdvancedToggle } from "./components/NodeAdvancedToggle"; -import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput"; -import { NodeExecutionBadge } from "./components/NodeExecutionBadge"; -import { cn } from "@/lib/utils"; -import { WebhookDisclaimer } from "./components/WebhookDisclaimer"; -import { AyrshareConnectButton } from "./components/AyrshareConnectButton"; import { NodeModelMetadata } from "@/app/api/__generated__/models/nodeModelMetadata"; +import { preprocessInputSchema } from "@/components/renderers/InputRenderer/utils/input-schema-pre-processor"; +import { cn } from "@/lib/utils"; +import { RJSFSchema } from "@rjsf/utils"; +import { NodeProps, Node as XYNode } from "@xyflow/react"; +import React from "react"; +import { BlockUIType } from "../../../types"; +import { FormCreator } from "../FormCreator"; +import { OutputHandler } from "../OutputHandler"; +import { AyrshareConnectButton } from "./components/AyrshareConnectButton"; +import { NodeAdvancedToggle } from "./components/NodeAdvancedToggle"; +import { NodeContainer } from "./components/NodeContainer"; +import { 
NodeExecutionBadge } from "./components/NodeExecutionBadge"; +import { NodeHeader } from "./components/NodeHeader"; +import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput"; +import { NodeRightClickMenu } from "./components/NodeRightClickMenu"; +import { StickyNoteBlock } from "./components/StickyNoteBlock"; +import { WebhookDisclaimer } from "./components/WebhookDisclaimer"; export type CustomNodeData = { hardcodedValues: { @@ -88,7 +89,7 @@ export const CustomNode: React.FC> = React.memo( // Currently all block types share a similar design - that's why I am using the same component for all of them // If, in the future, some block type needs a drastically different design - we can create separate components for it - return ( + const node = (
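Taken together, the two CustomNode hunks here capture the rendered block JSX in a local `node` variable and then wrap it in the new right-click menu. A minimal sketch of the resulting shape, assuming the `NodeRightClickMenu` props (`nodeId`, `subGraphID`, `children`) introduced later in this diff; the container contents are elided:

```tsx
import React from "react";
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";

// Sketch only: how the two hunks compose. The real component renders
// NodeContainer/NodeHeader/FormCreator/OutputHandler inside `node`;
// the nodeId/subGraphID wiring here is assumed, not copied from the diff.
function CustomNodeSketch({
  nodeId,
  subGraphID,
}: {
  nodeId: string;
  subGraphID?: string;
}) {
  const node = <div>{/* NodeContainer with header, form, outputs ... */}</div>;

  return (
    <NodeRightClickMenu nodeId={nodeId} subGraphID={subGraphID}>
      {node}
    </NodeRightClickMenu>
  );
}
```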
@@ -117,6 +118,15 @@ export const CustomNode: React.FC> = React.memo( ); + + return ( + + {node} + + ); }, ); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeContextMenu.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeContextMenu.tsx index 6e482122f6..1a0e23fead 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeContextMenu.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeContextMenu.tsx @@ -1,26 +1,31 @@ -import { Separator } from "@/components/__legacy__/ui/separator"; +import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore"; +import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore"; import { DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, DropdownMenuTrigger, } from "@/components/molecules/DropdownMenu/DropdownMenu"; -import { DotsThreeOutlineVerticalIcon } from "@phosphor-icons/react"; -import { Copy, Trash2, ExternalLink } from "lucide-react"; -import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore"; -import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore"; +import { + SecondaryDropdownMenuContent, + SecondaryDropdownMenuItem, + SecondaryDropdownMenuSeparator, +} from "@/components/molecules/SecondaryMenu/SecondaryMenu"; +import { + ArrowSquareOutIcon, + CopyIcon, + DotsThreeOutlineVerticalIcon, + TrashIcon, +} from "@phosphor-icons/react"; import { useReactFlow } from "@xyflow/react"; -export const NodeContextMenu = ({ - nodeId, - subGraphID, -}: { +type Props = { nodeId: string; subGraphID?: string; -}) => { +}; + +export const NodeContextMenu = ({ nodeId, subGraphID }: Props) => { const { deleteElements } = useReactFlow(); - const handleCopy = () => { + function handleCopy() { useNodeStore.setState((state) => ({ nodes: state.nodes.map((node) => ({ ...node, @@ -30,47 +35,47 @@ export const NodeContextMenu = ({ useCopyPasteStore.getState().copySelectedNodes(); useCopyPasteStore.getState().pasteNodes(); - }; + } - const handleDelete = () => { + function handleDelete() { deleteElements({ nodes: [{ id: nodeId }] }); - }; + } return ( - - - - Copy Node - + + + + Copy + + {subGraphID && ( - window.open(`/build?flowID=${subGraphID}`)} - className="hover:rounded-xlarge" - > - - Open Agent - + <> + window.open(`/build?flowID=${subGraphID}`)} + > + + Open agent + + + )} - - - - - Delete - - + + + Delete + + ); }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx index 4dadccef2b..e13aa37a31 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx @@ -1,25 +1,24 @@ -import { Text } from "@/components/atoms/Text/Text"; -import { beautifyString, cn } from "@/lib/utils"; -import { NodeCost } from "./NodeCost"; -import { NodeBadges } from "./NodeBadges"; -import { NodeContextMenu } from "./NodeContextMenu"; -import { CustomNodeData } from "../CustomNode"; import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore"; -import { useState } from "react"; 
+import { Text } from "@/components/atoms/Text/Text"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger, } from "@/components/atoms/Tooltip/BaseTooltip"; +import { beautifyString, cn } from "@/lib/utils"; +import { useState } from "react"; +import { CustomNodeData } from "../CustomNode"; +import { NodeBadges } from "./NodeBadges"; +import { NodeContextMenu } from "./NodeContextMenu"; +import { NodeCost } from "./NodeCost"; -export const NodeHeader = ({ - data, - nodeId, -}: { +type Props = { data: CustomNodeData; nodeId: string; -}) => { +}; + +export const NodeHeader = ({ data, nodeId }: Props) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); const title = (data.metadata?.customized_name as string) || data.title; const [isEditingTitle, setIsEditingTitle] = useState(false); @@ -69,7 +68,10 @@ export const NodeHeader = ({
- + {beautifyString(title).replace("Block", "").trim()}
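A few diffs below, the new `NodeRightClickMenu` component pairs Radix's context menu with a capture-phase `contextmenu` listener that swallows a second right-click arriving within 300 ms, so a rapid double right-click doesn't immediately reopen the menu. The core idea, extracted as a standalone sketch (the helper name and framing are illustrative, not from the codebase):

```ts
// Illustrative helper: ignore a repeated contextmenu event within `timeoutMs`.
// Registering in the capture phase (third argument `true`) lets it run before
// React's own handlers, mirroring the listener in NodeRightClickMenu.
export function suppressRapidContextMenu(el: HTMLElement, timeoutMs = 300) {
  let last = 0;

  function handler(e: MouseEvent) {
    const now = Date.now();
    if (now - last < timeoutMs) {
      e.stopImmediatePropagation(); // swallow the repeat click
      last = 0;
      return;
    }
    last = now;
  }

  el.addEventListener("contextmenu", handler, true);
  // Return a cleanup function, as a React effect would expect.
  return () => el.removeEventListener("contextmenu", handler, true);
}
```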
diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx index c505282e7b..31b89315d6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx @@ -151,7 +151,7 @@ export const NodeDataViewer: FC = ({
- {outputItems.length > 0 && ( + {outputItems.length > 1 && ( ({ value: item.value, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeRightClickMenu.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeRightClickMenu.tsx new file mode 100644 index 0000000000..a56e42544f --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeRightClickMenu.tsx @@ -0,0 +1,104 @@ +import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore"; +import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore"; +import { + SecondaryMenuContent, + SecondaryMenuItem, + SecondaryMenuSeparator, +} from "@/components/molecules/SecondaryMenu/SecondaryMenu"; +import { ArrowSquareOutIcon, CopyIcon, TrashIcon } from "@phosphor-icons/react"; +import * as ContextMenu from "@radix-ui/react-context-menu"; +import { useReactFlow } from "@xyflow/react"; +import { useEffect, useRef } from "react"; +import { CustomNode } from "../CustomNode"; + +type Props = { + nodeId: string; + subGraphID?: string; + children: React.ReactNode; +}; + +const DOUBLE_CLICK_TIMEOUT = 300; + +export function NodeRightClickMenu({ nodeId, subGraphID, children }: Props) { + const { deleteElements } = useReactFlow(); + const lastRightClickTime = useRef(0); + const containerRef = useRef(null); + + function copyNode() { + useNodeStore.setState((state) => ({ + nodes: state.nodes.map((node) => ({ + ...node, + selected: node.id === nodeId, + })), + })); + + useCopyPasteStore.getState().copySelectedNodes(); + useCopyPasteStore.getState().pasteNodes(); + } + + function deleteNode() { + deleteElements({ nodes: [{ id: nodeId }] }); + } + + useEffect(() => { + const container = containerRef.current; + if (!container) return; + + function handleContextMenu(e: MouseEvent) { + const now = Date.now(); + const timeSinceLastClick = now - lastRightClickTime.current; + + if (timeSinceLastClick < DOUBLE_CLICK_TIMEOUT) { + e.stopImmediatePropagation(); + lastRightClickTime.current = 0; + return; + } + + lastRightClickTime.current = now; + } + + container.addEventListener("contextmenu", handleContextMenu, true); + + return () => { + container.removeEventListener("contextmenu", handleContextMenu, true); + }; + }, []); + + return ( + + +
{children}
+
+ + + + Copy + + + + {subGraphID && ( + <> + window.open(`/build?flowID=${subGraphID}`)} + > + + Open agent + + + + )} + + + + Delete + + +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts index 39384485f5..46032a67ea 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts @@ -89,6 +89,18 @@ export function extractOptions( // get display type and color for schema types [need for type display next to field name] export const getTypeDisplayInfo = (schema: any) => { + if ( + schema?.type === "array" && + "format" in schema && + schema.format === "table" + ) { + return { + displayType: "table", + colorClass: "!text-indigo-500", + hexColor: "#6366f1", + }; + } + if (schema?.type === "string" && schema?.format) { const formatMap: Record< string, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/uiSchema.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/uiSchema.ts index ad1fab7c95..065e697828 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/uiSchema.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/uiSchema.ts @@ -1,6 +1,6 @@ export const uiSchema = { credentials: { - "ui:field": "credentials", + "ui:field": "custom/credential_field", provider: { "ui:widget": "hidden" }, type: { "ui:widget": "hidden" }, id: { "ui:autofocus": true }, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuFilters/BlockMenuFilters.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuFilters/BlockMenuFilters.tsx new file mode 100644 index 0000000000..ebcea9eee6 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuFilters/BlockMenuFilters.tsx @@ -0,0 +1,57 @@ +import { useBlockMenuStore } from "@/app/(platform)/build/stores/blockMenuStore"; +import { FilterChip } from "../FilterChip"; +import { categories } from "./constants"; +import { FilterSheet } from "../FilterSheet/FilterSheet"; +import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem"; + +export const BlockMenuFilters = () => { + const { + filters, + addFilter, + removeFilter, + categoryCounts, + creators, + addCreator, + removeCreator, + } = useBlockMenuStore(); + + const handleFilterClick = (filter: GetV2BuilderSearchFilterAnyOfItem) => { + if (filters.includes(filter)) { + removeFilter(filter); + } else { + addFilter(filter); + } + }; + + const handleCreatorClick = (creator: string) => { + if (creators.includes(creator)) { + removeCreator(creator); + } else { + addCreator(creator); + } + }; + + return ( +
+ + {creators.length > 0 && + creators.map((creator) => ( + handleCreatorClick(creator)} + /> + ))} + {categories.map((category) => ( + handleFilterClick(category.key)} + number={categoryCounts[category.key] ?? 0} + /> + ))} +
+ ); +}; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuFilters/constants.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuFilters/constants.ts new file mode 100644 index 0000000000..b438aae91b --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuFilters/constants.ts @@ -0,0 +1,15 @@ +import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem"; +import { CategoryKey } from "./types"; + +export const categories: Array<{ key: CategoryKey; name: string }> = [ + { key: GetV2BuilderSearchFilterAnyOfItem.blocks, name: "Blocks" }, + { + key: GetV2BuilderSearchFilterAnyOfItem.integrations, + name: "Integrations", + }, + { + key: GetV2BuilderSearchFilterAnyOfItem.marketplace_agents, + name: "Marketplace agents", + }, + { key: GetV2BuilderSearchFilterAnyOfItem.my_agents, name: "My agents" }, +]; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuFilters/types.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuFilters/types.ts new file mode 100644 index 0000000000..8fec9ef64d --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuFilters/types.ts @@ -0,0 +1,26 @@ +import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem"; + +export type DefaultStateType = + | "suggestion" + | "all_blocks" + | "input_blocks" + | "action_blocks" + | "output_blocks" + | "integrations" + | "marketplace_agents" + | "my_agents"; + +export type CategoryKey = GetV2BuilderSearchFilterAnyOfItem; + +export interface Filters { + categories: { + blocks: boolean; + integrations: boolean; + marketplace_agents: boolean; + my_agents: boolean; + providers: boolean; + }; + createdBy: string[]; +} + +export type CategoryCounts = Record; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearch/BlockMenuSearch.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearch/BlockMenuSearch.tsx index de339431e8..26723eebcc 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearch/BlockMenuSearch.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearch/BlockMenuSearch.tsx @@ -1,111 +1,14 @@ import { Text } from "@/components/atoms/Text/Text"; -import { useBlockMenuSearch } from "./useBlockMenuSearch"; -import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll"; -import { LoadingSpinner } from "@/components/__legacy__/ui/loading"; -import { SearchResponseItemsItem } from "@/app/api/__generated__/models/searchResponseItemsItem"; -import { MarketplaceAgentBlock } from "../MarketplaceAgentBlock"; -import { Block } from "../Block"; -import { UGCAgentBlock } from "../UGCAgentBlock"; -import { getSearchItemType } from "./helper"; -import { useBlockMenuStore } from "../../../../stores/blockMenuStore"; import { blockMenuContainerStyle } from "../style"; -import { cn } from "@/lib/utils"; -import { NoSearchResult } from "../NoSearchResult"; +import { BlockMenuFilters } from "../BlockMenuFilters/BlockMenuFilters"; 
+import { BlockMenuSearchContent } from "../BlockMenuSearchContent/BlockMenuSearchContent"; export const BlockMenuSearch = () => { - const { - searchResults, - isFetchingNextPage, - fetchNextPage, - hasNextPage, - searchLoading, - handleAddLibraryAgent, - handleAddMarketplaceAgent, - addingLibraryAgentId, - addingMarketplaceAgentSlug, - } = useBlockMenuSearch(); - const { searchQuery } = useBlockMenuStore(); - - if (searchLoading) { - return ( -
- -
- ); - } - - if (searchResults.length === 0) { - return ; - } - return (
+ Search results - } - className="space-y-2.5" - > - {searchResults.map((item: SearchResponseItemsItem, index: number) => { - const { type, data } = getSearchItemType(item); - // backend give support to these 3 types only [right now] - we need to give support to integration and ai agent types in follow up PRs - switch (type) { - case "store_agent": - return ( - - handleAddMarketplaceAgent({ - creator_name: data.creator, - slug: data.slug, - }) - } - /> - ); - case "block": - return ( - - ); - - case "library_agent": - return ( - handleAddLibraryAgent(data)} - /> - ); - - default: - return null; - } - })} - +
); }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/BlockMenuSearchContent.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/BlockMenuSearchContent.tsx new file mode 100644 index 0000000000..7229c44ed7 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/BlockMenuSearchContent.tsx @@ -0,0 +1,108 @@ +import { SearchResponseItemsItem } from "@/app/api/__generated__/models/searchResponseItemsItem"; +import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; +import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll"; +import { getSearchItemType } from "./helper"; +import { MarketplaceAgentBlock } from "../MarketplaceAgentBlock"; +import { Block } from "../Block"; +import { UGCAgentBlock } from "../UGCAgentBlock"; +import { useBlockMenuSearchContent } from "./useBlockMenuSearchContent"; +import { useBlockMenuStore } from "@/app/(platform)/build/stores/blockMenuStore"; +import { cn } from "@/lib/utils"; +import { blockMenuContainerStyle } from "../style"; +import { NoSearchResult } from "../NoSearchResult"; + +export const BlockMenuSearchContent = () => { + const { + searchResults, + isFetchingNextPage, + fetchNextPage, + hasNextPage, + searchLoading, + handleAddLibraryAgent, + handleAddMarketplaceAgent, + addingLibraryAgentId, + addingMarketplaceAgentSlug, + } = useBlockMenuSearchContent(); + + const { searchQuery } = useBlockMenuStore(); + + if (searchLoading) { + return ( +
+ +
+ ); + } + + if (searchResults.length === 0) { + return ; + } + + return ( + } + className="space-y-2.5" + > + {searchResults.map((item: SearchResponseItemsItem, index: number) => { + const { type, data } = getSearchItemType(item); + // the backend supports only these 3 types right now - support for integration and AI agent types will be added in follow-up PRs + switch (type) { + case "store_agent": + return ( + + handleAddMarketplaceAgent({ + creator_name: data.creator, + slug: data.slug, + }) + } + /> + ); + case "block": + return ( + + ); + + case "library_agent": + return ( + handleAddLibraryAgent(data)} + /> + ); + + default: + return null; + } + })} + + ); +}; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearch/helper.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/helper.ts similarity index 100% rename from autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearch/helper.ts rename to autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/helper.ts diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearch/useBlockMenuSearch.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/useBlockMenuSearchContent.tsx similarity index 83% rename from autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearch/useBlockMenuSearch.ts rename to autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/useBlockMenuSearchContent.tsx index beff80a984..9da9cb4cbc 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearch/useBlockMenuSearch.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/useBlockMenuSearchContent.tsx @@ -23,9 +23,19 @@ import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { getQueryClient } from "@/lib/react-query/queryClient"; import { useToast } from "@/components/molecules/Toast/use-toast"; import * as Sentry from "@sentry/nextjs"; +import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem"; + +export const useBlockMenuSearchContent = () => { + const { + searchQuery, + searchId, + setSearchId, + filters, + setCreatorsList, + creators, + setCategoryCounts, + } = useBlockMenuStore(); -export const useBlockMenuSearch = () => { - const { searchQuery, searchId, setSearchId } = useBlockMenuStore(); const { toast } = useToast(); const { addAgentToBuilder, addLibraryAgentToBuilder } = useAddAgentToBuilder(); @@ -57,6 +67,8 @@ export const useBlockMenuSearch = () => { page_size: 8, search_query: searchQuery, search_id: searchId, + filter: filters.length > 0 ? filters : undefined, + by_creator: creators.length > 0 ?
creators : undefined, }, { query: { getNextPageParam: getPaginationNextPageNumber }, @@ -98,6 +110,26 @@ export const useBlockMenuSearch = () => { } }, [searchQueryData, searchId, setSearchId]); + // from all the results, we need to get all the unique creators + useEffect(() => { + if (!searchQueryData?.pages?.length) { + return; + } + const latestData = okData(searchQueryData.pages.at(-1)); + setCategoryCounts( + (latestData?.total_items as Record< + GetV2BuilderSearchFilterAnyOfItem, + number + >) || { + blocks: 0, + integrations: 0, + marketplace_agents: 0, + my_agents: 0, + }, + ); + setCreatorsList(latestData?.items || []); + }, [searchQueryData]); + useEffect(() => { if (searchId && !searchQuery) { resetSearchSession(); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterChip.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterChip.tsx index 69931958b3..23197ab612 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterChip.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterChip.tsx @@ -1,7 +1,9 @@ import { Button } from "@/components/__legacy__/ui/button"; import { cn } from "@/lib/utils"; -import { X } from "lucide-react"; -import React, { ButtonHTMLAttributes } from "react"; +import { XIcon } from "@phosphor-icons/react"; +import { AnimatePresence, motion } from "framer-motion"; + +import React, { ButtonHTMLAttributes, useState } from "react"; interface Props extends ButtonHTMLAttributes { selected?: boolean; @@ -16,39 +18,51 @@ export const FilterChip: React.FC = ({ className, ...rest }) => { + const [isHovered, setIsHovered] = useState(false); return ( - + > + {name} + + {selected && !isHovered && ( + + + + )} + {number !== undefined && isHovered && ( + + {number > 100 ? "100+" : number} + + )} + + ); }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterSheet/FilterSheet.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterSheet/FilterSheet.tsx new file mode 100644 index 0000000000..dc7c428245 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterSheet/FilterSheet.tsx @@ -0,0 +1,156 @@ +import { FilterChip } from "../FilterChip"; +import { cn } from "@/lib/utils"; +import { CategoryKey } from "../BlockMenuFilters/types"; +import { AnimatePresence, motion } from "framer-motion"; +import { XIcon } from "@phosphor-icons/react"; +import { Button } from "@/components/atoms/Button/Button"; +import { Text } from "@/components/atoms/Text/Text"; +import { Separator } from "@/components/__legacy__/ui/separator"; +import { Checkbox } from "@/components/__legacy__/ui/checkbox"; +import { useFilterSheet } from "./useFilterSheet"; +import { INITIAL_CREATORS_TO_SHOW } from "./constant"; + +export function FilterSheet({ + categories, +}: { + categories: Array<{ key: CategoryKey; name: string }>; +}) { + const { + isOpen, + localCategories, + localCreators, + displayedCreatorsCount, + handleLocalCategoryChange, + handleToggleShowMoreCreators, + handleLocalCreatorChange, + handleClearFilters, + handleCloseButton, + handleApplyFilters, + hasLocalActiveFilters, + visibleCreators, + creators, + handleOpenFilters, + hasActiveFilters, + } = useFilterSheet(); + + return ( +
+ + + + {isOpen && ( + + {/* Top section */} +
+ Filters + +
+ + + + {/* Category section */} +
+ Categories +
+ {categories.map((category) => ( +
+ + handleLocalCategoryChange(category.key) + } + className="border border-[#D4D4D4] shadow-none data-[state=checked]:border-none data-[state=checked]:bg-violet-700 data-[state=checked]:text-white" + /> + +
+ ))} +
+
+ + {/* Created by section */} +
+

+ Created by +

+
+ {visibleCreators.map((creator, i) => ( +
+ handleLocalCreatorChange(creator)} + className="border border-[#D4D4D4] shadow-none data-[state=checked]:border-none data-[state=checked]:bg-violet-700 data-[state=checked]:text-white" + /> + +
+ ))} +
+ {creators.length > INITIAL_CREATORS_TO_SHOW && ( + + )} +
+ + {/* Footer section */} +
+ + + +
+
+ )} +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterSheet/constant.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterSheet/constant.ts new file mode 100644 index 0000000000..8e05dc1037 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterSheet/constant.ts @@ -0,0 +1 @@ +export const INITIAL_CREATORS_TO_SHOW = 5; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterSheet/useFilterSheet.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterSheet/useFilterSheet.ts new file mode 100644 index 0000000000..200671f4e7 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/FilterSheet/useFilterSheet.ts @@ -0,0 +1,100 @@ +import { useBlockMenuStore } from "@/app/(platform)/build/stores/blockMenuStore"; +import { useState } from "react"; +import { INITIAL_CREATORS_TO_SHOW } from "./constant"; +import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem"; + +export const useFilterSheet = () => { + const { filters, creators_list, creators, setFilters, setCreators } = + useBlockMenuStore(); + + const [isOpen, setIsOpen] = useState(false); + const [localCategories, setLocalCategories] = + useState(filters); + const [localCreators, setLocalCreators] = useState(creators); + const [displayedCreatorsCount, setDisplayedCreatorsCount] = useState( + INITIAL_CREATORS_TO_SHOW, + ); + + const handleLocalCategoryChange = ( + category: GetV2BuilderSearchFilterAnyOfItem, + ) => { + setLocalCategories((prev) => { + if (prev.includes(category)) { + return prev.filter((c) => c !== category); + } + return [...prev, category]; + }); + }; + + const hasActiveFilters = () => { + return filters.length > 0 || creators.length > 0; + }; + + const handleToggleShowMoreCreators = () => { + if (displayedCreatorsCount < creators.length) { + setDisplayedCreatorsCount(creators.length); + } else { + setDisplayedCreatorsCount(INITIAL_CREATORS_TO_SHOW); + } + }; + + const handleLocalCreatorChange = (creator: string) => { + setLocalCreators((prev) => { + if (prev.includes(creator)) { + return prev.filter((c) => c !== creator); + } + return [...prev, creator]; + }); + }; + + const handleClearFilters = () => { + setLocalCategories([]); + setLocalCreators([]); + setDisplayedCreatorsCount(INITIAL_CREATORS_TO_SHOW); + }; + + const handleCloseButton = () => { + setIsOpen(false); + setLocalCategories(filters); + setLocalCreators(creators); + setDisplayedCreatorsCount(INITIAL_CREATORS_TO_SHOW); + }; + + const handleApplyFilters = () => { + setFilters(localCategories); + setCreators(localCreators); + setIsOpen(false); + }; + + const handleOpenFilters = () => { + setIsOpen(true); + setLocalCategories(filters); + setLocalCreators(creators); + }; + + const hasLocalActiveFilters = () => { + return localCategories.length > 0 || localCreators.length > 0; + }; + + const visibleCreators = creators_list.slice(0, displayedCreatorsCount); + + return { + creators, + isOpen, + setIsOpen, + localCategories, + localCreators, + displayedCreatorsCount, + setDisplayedCreatorsCount, + handleLocalCategoryChange, + handleToggleShowMoreCreators, + handleLocalCreatorChange, + handleClearFilters, + handleCloseButton, + handleOpenFilters, + handleApplyFilters, + 
hasLocalActiveFilters, + visibleCreators, + hasActiveFilters, + }; +}; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/blockMenuStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/blockMenuStore.ts index ea50a03979..31b9eda338 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/blockMenuStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/blockMenuStore.ts @@ -1,12 +1,30 @@ import { create } from "zustand"; import { DefaultStateType } from "../components/NewControlPanel/NewBlockMenu/types"; +import { SearchResponseItemsItem } from "@/app/api/__generated__/models/searchResponseItemsItem"; +import { getSearchItemType } from "../components/NewControlPanel/NewBlockMenu/BlockMenuSearchContent/helper"; +import { StoreAgent } from "@/app/api/__generated__/models/storeAgent"; +import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem"; type BlockMenuStore = { searchQuery: string; searchId: string | undefined; defaultState: DefaultStateType; integration: string | undefined; + filters: GetV2BuilderSearchFilterAnyOfItem[]; + creators: string[]; + creators_list: string[]; + categoryCounts: Record; + setCategoryCounts: ( + counts: Record, + ) => void; + setCreatorsList: (searchData: SearchResponseItemsItem[]) => void; + addCreator: (creator: string) => void; + setCreators: (creators: string[]) => void; + removeCreator: (creator: string) => void; + addFilter: (filter: GetV2BuilderSearchFilterAnyOfItem) => void; + setFilters: (filters: GetV2BuilderSearchFilterAnyOfItem[]) => void; + removeFilter: (filter: GetV2BuilderSearchFilterAnyOfItem) => void; setSearchQuery: (query: string) => void; setSearchId: (id: string | undefined) => void; setDefaultState: (state: DefaultStateType) => void; @@ -19,11 +37,44 @@ export const useBlockMenuStore = create((set) => ({ searchId: undefined, defaultState: DefaultStateType.SUGGESTION, integration: undefined, + filters: [], + creators: [], // creator filters that are applied to the search results + creators_list: [], // all creators that are available to filter by + categoryCounts: { + blocks: 0, + integrations: 0, + marketplace_agents: 0, + my_agents: 0, + }, + setCategoryCounts: (counts) => set({ categoryCounts: counts }), + setCreatorsList: (searchData) => { + const marketplaceAgents = searchData.filter((item) => { + return getSearchItemType(item).type === "store_agent"; + }) as StoreAgent[]; + + const newCreators = marketplaceAgents.map((agent) => agent.creator); + + set((state) => ({ + creators_list: Array.from( + new Set([...state.creators_list, ...newCreators]), + ), + })); + }, + setCreators: (creators) => set({ creators }), + setFilters: (filters) => set({ filters }), setSearchQuery: (query) => set({ searchQuery: query }), setSearchId: (id) => set({ searchId: id }), setDefaultState: (state) => set({ defaultState: state }), setIntegration: (integration) => set({ integration }), + addFilter: (filter) => + set((state) => ({ filters: [...state.filters, filter] })), + removeFilter: (filter) => + set((state) => ({ filters: state.filters.filter((f) => f !== filter) })), + addCreator: (creator) => + set((state) => ({ creators: [...state.creators, creator] })), + removeCreator: (creator) => + set((state) => ({ creators: state.creators.filter((c) => c !== creator) })), reset: () => set({ searchQuery: "", diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts 
b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts index 96478c5b6f..c151f90faa 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts @@ -68,6 +68,9 @@ type NodeStore = { clearAllNodeErrors: () => void; // Add this syncHardcodedValuesWithHandleIds: (nodeId: string) => void; + + // Credentials optional helpers + setCredentialsOptional: (nodeId: string, optional: boolean) => void; }; export const useNodeStore = create((set, get) => ({ @@ -226,6 +229,9 @@ export const useNodeStore = create((set, get) => ({ ...(node.data.metadata?.customized_name !== undefined && { customized_name: node.data.metadata.customized_name, }), + ...(node.data.metadata?.credentials_optional !== undefined && { + credentials_optional: node.data.metadata.credentials_optional, + }), }, }; }, @@ -342,4 +348,30 @@ export const useNodeStore = create((set, get) => ({ })); } }, + + setCredentialsOptional: (nodeId: string, optional: boolean) => { + set((state) => ({ + nodes: state.nodes.map((n) => + n.id === nodeId + ? { + ...n, + data: { + ...n.data, + metadata: { + ...n.data.metadata, + credentials_optional: optional, + }, + }, + } + : n, + ), + })); + + const newState = { + nodes: get().nodes, + edges: useEdgeStore.getState().edges, + }; + + useHistoryStore.getState().pushState(newState); + }, })); diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs.tsx index 60d61fab57..a0f9376aa2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/CredentialsInputs.tsx @@ -34,7 +34,9 @@ type Props = { onSelectCredentials: (newValue?: CredentialsMetaInput) => void; onLoaded?: (loaded: boolean) => void; readOnly?: boolean; + isOptional?: boolean; showTitle?: boolean; + variant?: "default" | "node"; }; export function CredentialsInput({ @@ -45,7 +47,9 @@ export function CredentialsInput({ siblingInputs, onLoaded, readOnly = false, + isOptional = false, showTitle = true, + variant = "default", }: Props) { const hookData = useCredentialsInput({ schema, @@ -54,6 +58,7 @@ export function CredentialsInput({ siblingInputs, onLoaded, readOnly, + isOptional, }); if (!isLoaded(hookData)) { @@ -94,7 +99,14 @@ export function CredentialsInput({
{showTitle && (
- {displayName} credentials + + {displayName} credentials + {isOptional && ( + + (optional) + + )} + {schema.description && ( )} @@ -103,14 +115,17 @@ export function CredentialsInput({ {hasCredentialsToShow ? ( <> - {credentialsToShow.length > 1 && !readOnly ? ( + {(credentialsToShow.length > 1 || isOptional) && !readOnly ? ( onSelectCredential(undefined)} readOnly={readOnly} + allowNone={isOptional} + variant={variant} /> ) : (
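With `isOptional` and `variant` now part of `CredentialsInput`'s props, a builder-node call site can opt into the compact styling and a clearable select. A hypothetical usage sketch follows; the wrapper component and the loose `schema` typing are illustrative, not part of this diff:

```tsx
import { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
import { CredentialsInput } from "./CredentialsInputs";

type Props = {
  schema: any; // stands in for the credentials field schema type
  value?: CredentialsMetaInput;
  onChange: (newValue?: CredentialsMetaInput) => void;
};

// Hypothetical call site exercising the new props.
export function NodeCredentialField({ schema, value, onChange }: Props) {
  return (
    <CredentialsInput
      schema={schema}
      selectedCredentials={value}
      onSelectCredentials={onChange}
      isOptional // renders the "(optional)" hint and allows clearing the selection
      variant="node" // compact row styling for the builder-node context
    />
  );
}
```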
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx index 21ec1200e4..2d0358aacb 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialRow/CredentialRow.tsx @@ -30,6 +30,8 @@ type CredentialRowProps = { readOnly?: boolean; showCaret?: boolean; asSelectTrigger?: boolean; + /** When "node", applies compact styling for node context */ + variant?: "default" | "node"; }; export function CredentialRow({ @@ -41,14 +43,22 @@ export function CredentialRow({ readOnly = false, showCaret = false, asSelectTrigger = false, + variant = "default", }: CredentialRowProps) { const ProviderIcon = providerIcons[provider] || fallbackIcon; + const isNodeVariant = variant === "node"; return (
-
+
{getCredentialDisplayName(credential, displayName)} - - {"*".repeat(MASKED_KEY_LENGTH)} - + {!(asSelectTrigger && isNodeVariant) && ( + + {"*".repeat(MASKED_KEY_LENGTH)} + + )}
{showCaret && !asSelectTrigger && ( diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx index 7adfa5772b..6e1ec2afb1 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/CredentialsInputs/components/CredentialsSelect/CredentialsSelect.tsx @@ -7,6 +7,7 @@ import { } from "@/components/__legacy__/ui/select"; import { Text } from "@/components/atoms/Text/Text"; import { CredentialsMetaInput } from "@/lib/autogpt-server-api/types"; +import { cn } from "@/lib/utils"; import { useEffect } from "react"; import { getCredentialDisplayName } from "../../helpers"; import { CredentialRow } from "../CredentialRow/CredentialRow"; @@ -23,7 +24,11 @@ interface Props { displayName: string; selectedCredentials?: CredentialsMetaInput; onSelectCredential: (credentialId: string) => void; + onClearCredential?: () => void; readOnly?: boolean; + allowNone?: boolean; + /** When "node", applies compact styling for node context */ + variant?: "default" | "node"; } export function CredentialsSelect({ @@ -32,22 +37,38 @@ export function CredentialsSelect({ displayName, selectedCredentials, onSelectCredential, + onClearCredential, readOnly = false, + allowNone = true, + variant = "default", }: Props) { - // Auto-select first credential if none is selected + // Auto-select first credential if none is selected (only if allowNone is false) useEffect(() => { - if (!selectedCredentials && credentials.length > 0) { + if (!allowNone && !selectedCredentials && credentials.length > 0) { onSelectCredential(credentials[0].id); } - }, [selectedCredentials, credentials, onSelectCredential]); + }, [allowNone, selectedCredentials, credentials, onSelectCredential]); + + const handleValueChange = (value: string) => { + if (value === "__none__") { + onClearCredential?.(); + } else { + onSelectCredential(value); + } + }; return (
setIsFocused(true)} - onBlur={() => !inputRef.current?.value && setIsFocused(false)} + label="Search agents" + id="library-search-bar" + hideLabel onChange={handleSearchInput} - className="flex-1 border-none font-sans text-[16px] font-normal leading-7 shadow-none focus:shadow-none focus:ring-0" + className="min-w-[18rem] pl-12 lg:min-w-[30rem]" type="text" data-testid="library-textbox" placeholder="Search agents" /> - - {isFocused && inputRef.current?.value && ( - - )}
   );
 }
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySearchBar/useLibrarySearchbar.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySearchBar/useLibrarySearchbar.tsx
index f6428c6c4e..74b8e9874c 100644
--- a/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySearchBar/useLibrarySearchbar.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySearchBar/useLibrarySearchbar.tsx
@@ -1,36 +1,30 @@
-import { useRef, useState } from "react";
-import { useLibraryPageContext } from "../state-provider";
 import { debounce } from "lodash";
+import { useCallback, useEffect } from "react";
 
-export const useLibrarySearchbar = () => {
-  const inputRef = useRef<HTMLInputElement>(null);
-  const [isFocused, setIsFocused] = useState(false);
-  const { setSearchTerm } = useLibraryPageContext();
+interface Props {
+  setSearchTerm: (value: string) => void;
+}
 
-  const debouncedSearch = debounce((value: string) => {
-    setSearchTerm(value);
-  }, 300);
+export function useLibrarySearchbar({ setSearchTerm }: Props) {
+  const debouncedSearch = useCallback(
+    debounce((value: string) => {
+      setSearchTerm(value);
+    }, 300),
+    [setSearchTerm],
+  );
 
-  const handleSearchInput = (e: React.ChangeEvent<HTMLInputElement>) => {
+  useEffect(() => {
+    return () => {
+      debouncedSearch.cancel();
+    };
+  }, [debouncedSearch]);
+
+  function handleSearchInput(e: React.ChangeEvent<HTMLInputElement>) {
     const searchTerm = e.target.value;
     debouncedSearch(searchTerm);
-  };
-
-  const handleClear = (e: React.MouseEvent) => {
-    if (inputRef.current) {
-      inputRef.current.value = "";
-      inputRef.current.blur();
-      setSearchTerm("");
-      e.preventDefault();
-    }
-    setIsFocused(false);
-  };
+  }
 
   return {
-    handleClear,
     handleSearchInput,
-    isFocused,
-    inputRef,
-    setIsFocused,
   };
-};
+}
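The rewritten hook above memoizes the debounced setter and cancels it on unmount. The same pattern in isolation, as a minimal sketch assuming only `lodash` and React (`useDebouncedSetter` is an illustrative name, not part of the codebase):

```ts
import { debounce } from "lodash";
import { useCallback, useEffect } from "react";

export function useDebouncedSetter(set: (value: string) => void, wait = 300) {
  // Recreate the debounced wrapper only when `set` or `wait` changes,
  // mirroring the useCallback-around-debounce shape in the hunk above.
  const debounced = useCallback(debounce(set, wait), [set, wait]);

  useEffect(() => {
    // Without this cleanup, a pending trailing invocation could still call
    // `set` after the component has unmounted.
    return () => debounced.cancel();
  }, [debounced]);

  return debounced;
}
```

A caller would wire it up as `const debouncedSearch = useDebouncedSetter(setSearchTerm)` and invoke `debouncedSearch(e.target.value)` from the input's onChange handler.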
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySortMenu/LibrarySortMenu.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySortMenu/LibrarySortMenu.tsx
index ac4ed060f2..de37af5fad 100644
--- a/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySortMenu/LibrarySortMenu.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySortMenu/LibrarySortMenu.tsx
@@ -1,5 +1,5 @@
 "use client";
-import { ArrowDownNarrowWideIcon } from "lucide-react";
+import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
 import {
   Select,
   SelectContent,
@@ -8,11 +8,15 @@
   SelectTrigger,
   SelectValue,
 } from "@/components/__legacy__/ui/select";
-import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
+import { ArrowDownNarrowWideIcon } from "lucide-react";
 import { useLibrarySortMenu } from "./useLibrarySortMenu";
 
-export default function LibrarySortMenu(): React.ReactNode {
-  const { handleSortChange } = useLibrarySortMenu();
+interface Props {
+  setLibrarySort: (value: LibraryAgentSort) => void;
+}
+
+export function LibrarySortMenu({ setLibrarySort }: Props) {
+  const { handleSortChange } = useLibrarySortMenu({ setLibrarySort });
 
   return (
          sort by
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySortMenu/useLibrarySortMenu.ts b/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySortMenu/useLibrarySortMenu.ts
index d2575c8936..e6d6f2d127 100644
--- a/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySortMenu/useLibrarySortMenu.ts
+++ b/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySortMenu/useLibrarySortMenu.ts
@@ -1,11 +1,11 @@
 import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
-import { useLibraryPageContext } from "../state-provider";
 
-export const useLibrarySortMenu = () => {
-  const { setLibrarySort } = useLibraryPageContext();
+interface Props {
+  setLibrarySort: (value: LibraryAgentSort) => void;
+}
 
+export function useLibrarySortMenu({ setLibrarySort }: Props) {
   const handleSortChange = (value: LibraryAgentSort) => {
-    // Simply updating the sort state - React Query will handle the rest
     setLibrarySort(value);
   };
 
@@ -24,4 +24,4 @@
   handleSortChange,
   getSortLabel,
 };
-};
+}
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/LibraryUploadAgentDialog/LibraryUploadAgentDialog.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/LibraryUploadAgentDialog/LibraryUploadAgentDialog.tsx
index d92bbe86fe..1a6999721e 100644
--- a/autogpt_platform/frontend/src/app/(platform)/library/components/LibraryUploadAgentDialog/LibraryUploadAgentDialog.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/library/components/LibraryUploadAgentDialog/LibraryUploadAgentDialog.tsx
@@ -1,192 +1,134 @@
 "use client";
-import { Upload, X } from "lucide-react";
-import { Button } from "@/components/__legacy__/Button";
-import {
-  Dialog,
-  DialogContent,
-  DialogDescription,
-  DialogHeader,
-  DialogTitle,
-  DialogTrigger,
-} from "@/components/__legacy__/ui/dialog";
-import { z } from "zod";
-import { FileUploader } from "react-drag-drop-files";
+import { Button } from "@/components/atoms/Button/Button";
+import { FileInput } from "@/components/atoms/FileInput/FileInput";
+import { Input } from "@/components/atoms/Input/Input";
+import { Dialog } from "@/components/molecules/Dialog/Dialog";
 import {
   Form,
   FormControl,
   FormField,
   FormItem,
-  FormLabel,
   FormMessage,
-} from "@/components/__legacy__/ui/form";
-import { Input } from "@/components/__legacy__/ui/input";
-import { Textarea } from "@/components/__legacy__/ui/textarea";
+} from "@/components/molecules/Form/Form";
+import { UploadSimpleIcon } from "@phosphor-icons/react";
+import { z } from "zod";
 import { useLibraryUploadAgentDialog } from "./useLibraryUploadAgentDialog";
 
-const fileTypes = ["JSON"];
-
-const fileSchema = z.custom<File>((val) => val instanceof File, {
-  message: "Must be a File object",
-});
-
 export const uploadAgentFormSchema = z.object({
-  agentFile: fileSchema,
+  agentFile: z.string().min(1, "Agent file is required"),
   agentName: z.string().min(1, "Agent name is required"),
   agentDescription: z.string(),
 });
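With `agentFile` switched from a `File` object to a non-empty string (presumably whatever value the new `FileInput` atom emits), the schema now behaves as sketched below; the sample values and the import path are illustrative only.

```ts
import { z } from "zod";
import { uploadAgentFormSchema } from "./LibraryUploadAgentDialog";

type UploadAgentForm = z.infer<typeof uploadAgentFormSchema>;
// => { agentFile: string; agentName: string; agentDescription: string }

// Accepted: agentFile is any non-empty string.
uploadAgentFormSchema.parse({
  agentFile: "my-agent.json",
  agentName: "My Agent",
  agentDescription: "",
});

// Rejected with "Agent file is required": empty string fails min(1),
// and a raw File object no longer validates at all.
const result = uploadAgentFormSchema.safeParse({
  agentFile: "",
  agentName: "My Agent",
  agentDescription: "",
});
console.log(result.success); // false
```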
 
-export default function LibraryUploadAgentDialog(): React.ReactNode {
-  const { onSubmit, isUploading, isOpen, setIsOpen, isDroped, handleChange, form, setisDroped, agentObject, clearAgentFile } = useLibraryUploadAgentDialog();
+export default function LibraryUploadAgentDialog() {
+  const { onSubmit, isUploading, isOpen, setIsOpen, form, agentObject } = useLibraryUploadAgentDialog();
+
   return (
-
-
+      {
+        setIsOpen(false);
+      }}
+    >
+
-
-
-
-          Upload Agent
-
-            Upload your agent by providing a name, description, and JSON file.
-
+
+
+
+            (
+
+
+
+
+
+            )}
+          />
-
-
-            (
-
-              Agent name
-
-
-
-
-
-            )}
-          />
            (
+
+
+
+
+            )}
          />
-            (
-
-              Description
-
-