Compare commits


4 Commits

Author SHA1 Message Date

Nicholas Tindle
2110e252ff  Merge branch 'dev' into seer/xml-parsing-error-handling
2025-12-26 10:05:30 -06:00

claude[bot]
5b31b74b18  fix(backend): use gravitasml 0.1.4 from PyPI directly
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
2025-12-26 16:02:40 +00:00

Cursor Agent
267ea4d446  fix(backend): update gravitasml parser dependency
Pin the backend to gravitasml tag 0.1.4 so the XML parser fix from Significant-Gravitas/gravitasml#15 ships with this PR.
2025-12-18 17:59:29 +00:00

seer-by-sentry[bot]
42904c6a41  feat(blocks): Improve XML parsing error handling with specific AttributeError message
2025-10-18 07:42:41 +00:00
377 changed files with 8628 additions and 18790 deletions
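The XML-parsing commit above (42904c6a41) and the branch name describe the pattern sketched below: catch the low-level exception a parser raises on malformed input and re-raise it with a specific, readable message. This is a minimal stand-alone illustration only; it uses xml.etree as a stand-in parser, whereas the backend actually uses gravitasml, whose code is not shown in this comparison.

import xml.etree.ElementTree as ET


def parse_with_readable_errors(input_xml: str) -> dict:
    # Stand-in for the block's real parser call (gravitasml in the backend).
    try:
        root = ET.fromstring(input_xml)
        return {root.tag: root.text}
    except (AttributeError, ET.ParseError) as e:
        # Surface a specific, user-facing error instead of a bare low-level exception.
        raise ValueError(f"Unable to parse XML input: {e}") from e

For example, parse_with_readable_errors("<root>value</root> trailing") raises a ValueError with a readable message rather than propagating the parser's internal exception.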

View File

@@ -1,37 +0,0 @@
{
"worktreeCopyPatterns": [
".env*",
".vscode/**",
".auth/**",
".claude/**",
"autogpt_platform/.env*",
"autogpt_platform/backend/.env*",
"autogpt_platform/frontend/.env*",
"autogpt_platform/frontend/.auth/**",
"autogpt_platform/db/docker/.env*"
],
"worktreeCopyIgnores": [
"**/node_modules/**",
"**/dist/**",
"**/.git/**",
"**/Thumbs.db",
"**/.DS_Store",
"**/.next/**",
"**/__pycache__/**",
"**/.ruff_cache/**",
"**/.pytest_cache/**",
"**/*.pyc",
"**/playwright-report/**",
"**/logs/**",
"**/site/**"
],
"worktreePathTemplate": "$BASE_PATH.worktree",
"postCreateCmd": [
"cd autogpt_platform/autogpt_libs && poetry install",
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
"cd autogpt_platform/frontend && pnpm install",
"cd docs && pip install -r requirements.txt"
],
"terminalCommand": "code .",
"deleteBranchWithWorktree": false
}

View File

@@ -16,7 +16,6 @@
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py
# Platform - Market
!autogpt_platform/market/market/

View File

@@ -74,7 +74,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -90,7 +90,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -72,7 +72,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
@@ -108,16 +108,6 @@ jobs:
# run: pnpm playwright install --with-deps chromium
# Docker setup for development environment
- name: Free up disk space
run: |
# Remove large unused tools to free disk space for Docker builds
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker system prune -af
df -h
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -134,7 +134,7 @@ jobs:
run: poetry install
- name: Generate Prisma Client
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
- id: supabase
name: Start Supabase

View File

@@ -12,7 +12,6 @@ reset-db:
rm -rf db/docker/volumes/db/data
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
# View logs for core services
logs-core:
@@ -34,7 +33,6 @@ init-env:
migrate:
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
run-backend:
cd backend && poetry run app

View File

@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
RUN poetry run prisma generate
FROM debian:13-slim AS server_dependencies

View File

@@ -489,7 +489,7 @@ async def update_agent_version_in_library(
agent_graph_version: int,
) -> library_model.LibraryAgent:
"""
Updates the agent version in the library for any agent owned by the user.
Updates the agent version in the library if useGraphIsActiveVersion is True.
Args:
user_id: Owner of the LibraryAgent.
@@ -498,31 +498,20 @@ async def update_agent_version_in_library(
Raises:
DatabaseError: If there's an error with the update.
NotFoundError: If no library agent is found for this user and agent.
"""
logger.debug(
f"Updating agent version in library for user #{user_id}, "
f"agent #{agent_graph_id} v{agent_graph_version}"
)
async with transaction() as tx:
library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise(
try:
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
where={
"userId": user_id,
"agentGraphId": agent_graph_id,
"useGraphIsActiveVersion": True,
},
)
# Delete any conflicting LibraryAgent for the target version
await prisma.models.LibraryAgent.prisma(tx).delete_many(
where={
"userId": user_id,
"agentGraphId": agent_graph_id,
"agentGraphVersion": agent_graph_version,
"id": {"not": library_agent.id},
}
)
lib = await prisma.models.LibraryAgent.prisma(tx).update(
lib = await prisma.models.LibraryAgent.prisma().update(
where={"id": library_agent.id},
data={
"AgentGraph": {
@@ -536,13 +525,13 @@ async def update_agent_version_in_library(
},
include={"AgentGraph": True},
)
if lib is None:
raise NotFoundError(f"Library agent {library_agent.id} not found")
if lib is None:
raise NotFoundError(
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
)
return library_model.LibraryAgent.from_db(lib)
return library_model.LibraryAgent.from_db(lib)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating agent version in library: {e}")
raise DatabaseError("Failed to update agent version in library") from e
async def update_library_agent(
@@ -836,7 +825,6 @@ async def add_store_agent_to_library(
}
},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": SafeJson(
_initialize_graph_settings(graph_model).model_dump()
),

View File

@@ -48,7 +48,6 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str # ID of user who owns/created this agent graph
image_url: str | None
@@ -164,7 +163,6 @@ class LibraryAgent(pydantic.BaseModel):
id=agent.id,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
owner_user_id=agent.userId,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,

View File

@@ -42,7 +42,6 @@ async def test_get_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -65,7 +64,6 @@ async def test_get_library_agents_success(
id="test-agent-2",
graph_id="test-agent-2",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -140,7 +138,6 @@ async def test_get_favorite_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Favorite Agent 1",
description="Test Favorite Description 1",
image_url=None,
@@ -208,7 +205,6 @@ def test_add_agent_to_library_success(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,

View File

@@ -614,7 +614,6 @@ async def get_store_submissions(
submission_models = []
for sub in submissions:
submission_model = store_model.StoreSubmission(
listing_id=sub.listing_id,
agent_id=sub.agent_id,
agent_version=sub.agent_version,
name=sub.name,
@@ -668,48 +667,35 @@ async def delete_store_submission(
submission_id: str,
) -> bool:
"""
Delete a store submission version as the submitting user.
Delete a store listing submission as the submitting user.
Args:
user_id: ID of the authenticated user
submission_id: StoreListingVersion ID to delete
submission_id: ID of the submission to be deleted
Returns:
bool: True if successfully deleted
bool: True if the submission was successfully deleted, False otherwise
"""
logger.debug(f"Deleting store submission {submission_id} for user {user_id}")
try:
# Find the submission version with ownership check
version = await prisma.models.StoreListingVersion.prisma().find_first(
where={"id": submission_id}, include={"StoreListing": True}
# Verify the submission belongs to this user
submission = await prisma.models.StoreListing.prisma().find_first(
where={"agentGraphId": submission_id, "owningUserId": user_id}
)
if (
not version
or not version.StoreListing
or version.StoreListing.owningUserId != user_id
):
raise store_exceptions.SubmissionNotFoundError("Submission not found")
# Prevent deletion of approved submissions
if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
raise store_exceptions.InvalidOperationError(
"Cannot delete approved submissions"
if not submission:
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
raise store_exceptions.SubmissionNotFoundError(
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
)
# Delete the version
await prisma.models.StoreListingVersion.prisma().delete(
where={"id": version.id}
)
# Delete the submission
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
# Clean up empty listing if this was the last version
remaining = await prisma.models.StoreListingVersion.prisma().count(
where={"storeListingId": version.storeListingId}
logger.debug(
f"Successfully deleted submission {submission_id} for user {user_id}"
)
if remaining == 0:
await prisma.models.StoreListing.prisma().delete(
where={"id": version.storeListingId}
)
return True
except Exception as e:
@@ -773,15 +759,9 @@ async def create_store_submission(
logger.warning(
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
)
# Provide more user-friendly error message when agent_id is empty
if not agent_id or agent_id.strip() == "":
raise store_exceptions.AgentNotFoundError(
"No agent selected. Please select an agent before submitting to the store."
)
else:
raise store_exceptions.AgentNotFoundError(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
raise store_exceptions.AgentNotFoundError(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
# Check if listing already exists for this agent
existing_listing = await prisma.models.StoreListing.prisma().find_first(
@@ -853,7 +833,6 @@ async def create_store_submission(
logger.debug(f"Created store listing for agent {agent_id}")
# Return submission details
return store_model.StoreSubmission(
listing_id=listing.id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -965,56 +944,81 @@ async def edit_store_submission(
# Currently we are not allowing user to update the agent associated with a submission
# If we allow it in future, then we need a check here to verify the agent belongs to this user.
# Only allow editing of PENDING submissions
if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING:
# Check if we can edit this submission
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
raise store_exceptions.InvalidOperationError(
f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited."
"Cannot edit a rejected submission"
)
# For APPROVED submissions, we need to create a new version
if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
# Create a new version for the existing listing
return await create_store_version(
user_id=user_id,
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
store_listing_id=current_version.storeListingId,
name=name,
video_url=video_url,
agent_output_demo_url=agent_output_demo_url,
image_urls=image_urls,
description=description,
sub_heading=sub_heading,
categories=categories,
changes_summary=changes_summary,
recommended_schedule_cron=recommended_schedule_cron,
instructions=instructions,
)
# For PENDING submissions, we can update the existing version
# Update the existing version
updated_version = await prisma.models.StoreListingVersion.prisma().update(
where={"id": store_listing_version_id},
data=prisma.types.StoreListingVersionUpdateInput(
elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING:
# Update the existing version
updated_version = await prisma.models.StoreListingVersion.prisma().update(
where={"id": store_listing_version_id},
data=prisma.types.StoreListingVersionUpdateInput(
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
categories=categories,
subHeading=sub_heading,
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
instructions=instructions,
),
)
logger.debug(
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
)
if not updated_version:
raise DatabaseError("Failed to update store listing version")
return store_model.StoreSubmission(
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
sub_heading=sub_heading,
slug=current_version.StoreListing.slug,
description=description,
categories=categories,
subHeading=sub_heading,
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
instructions=instructions,
),
)
image_urls=image_urls,
date_submitted=updated_version.submittedAt or updated_version.createdAt,
status=updated_version.submissionStatus,
runs=0,
rating=0.0,
store_listing_version_id=updated_version.id,
changes_summary=changes_summary,
video_url=video_url,
categories=categories,
version=updated_version.version,
)
logger.debug(
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
)
if not updated_version:
raise DatabaseError("Failed to update store listing version")
return store_model.StoreSubmission(
listing_id=current_version.StoreListing.id,
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
name=name,
sub_heading=sub_heading,
slug=current_version.StoreListing.slug,
description=description,
instructions=instructions,
image_urls=image_urls,
date_submitted=updated_version.submittedAt or updated_version.createdAt,
status=updated_version.submissionStatus,
runs=0,
rating=0.0,
store_listing_version_id=updated_version.id,
changes_summary=changes_summary,
video_url=video_url,
categories=categories,
version=updated_version.version,
)
else:
raise store_exceptions.InvalidOperationError(
f"Cannot edit submission with status: {current_version.submissionStatus}"
)
except (
store_exceptions.SubmissionNotFoundError,
@@ -1093,78 +1097,38 @@ async def create_store_version(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
# Check if there's already a PENDING submission for this agent (any version)
existing_pending_submission = (
await prisma.models.StoreListingVersion.prisma().find_first(
where=prisma.types.StoreListingVersionWhereInput(
storeListingId=store_listing_id,
agentGraphId=agent_id,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isDeleted=False,
)
# Get the latest version number
latest_version = listing.Versions[0] if listing.Versions else None
next_version = (latest_version.version + 1) if latest_version else 1
# Create a new version for the existing listing
new_version = await prisma.models.StoreListingVersion.prisma().create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
instructions=instructions,
categories=categories,
subHeading=sub_heading,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
submittedAt=datetime.now(),
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
storeListingId=store_listing_id,
)
)
# Handle existing pending submission and create new one atomically
async with transaction() as tx:
# Get the latest version number first
latest_listing = await prisma.models.StoreListing.prisma(tx).find_first(
where=prisma.types.StoreListingWhereInput(
id=store_listing_id, owningUserId=user_id
),
include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
)
if not latest_listing:
raise store_exceptions.ListingNotFoundError(
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
)
latest_version = (
latest_listing.Versions[0] if latest_listing.Versions else None
)
next_version = (latest_version.version + 1) if latest_version else 1
# If there's an existing pending submission, delete it atomically before creating new one
if existing_pending_submission:
logger.info(
f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate"
)
await prisma.models.StoreListingVersion.prisma(tx).delete(
where={"id": existing_pending_submission.id}
)
logger.debug(
f"Deleted existing pending submission {existing_pending_submission.id}"
)
# Create a new version for the existing listing
new_version = await prisma.models.StoreListingVersion.prisma(tx).create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
instructions=instructions,
categories=categories,
subHeading=sub_heading,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
submittedAt=datetime.now(),
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
storeListingId=store_listing_id,
)
)
logger.debug(
f"Created new version for listing {store_listing_id} of agent {agent_id}"
)
# Return submission details
return store_model.StoreSubmission(
listing_id=listing.id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -1744,12 +1708,15 @@ async def review_store_submission(
# Convert to Pydantic model for consistency
return store_model.StoreSubmission(
listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
agent_id=submission.agentGraphId,
agent_version=submission.agentGraphVersion,
name=submission.name,
sub_heading=submission.subHeading,
slug=(submission.StoreListing.slug if submission.StoreListing else ""),
slug=(
submission.StoreListing.slug
if hasattr(submission, "storeListing") and submission.StoreListing
else ""
),
description=submission.description,
instructions=submission.instructions,
image_urls=submission.imageUrls or [],
@@ -1851,7 +1818,9 @@ async def get_admin_listings_with_versions(
where = prisma.types.StoreListingWhereInput(**where_dict)
include = prisma.types.StoreListingInclude(
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
order_by={"version": "desc"}
order_by=prisma.types._StoreListingVersion_version_OrderByInput(
version="desc"
)
),
OwningUser=True,
)
@@ -1876,7 +1845,6 @@ async def get_admin_listings_with_versions(
# If we have versions, turn them into StoreSubmission models
for version in listing.Versions or []:
version_model = store_model.StoreSubmission(
listing_id=listing.id,
agent_id=version.agentGraphId,
agent_version=version.agentGraphVersion,
name=version.name,

View File

@@ -110,7 +110,6 @@ class Profile(pydantic.BaseModel):
class StoreSubmission(pydantic.BaseModel):
listing_id: str
agent_id: str
agent_version: int
name: str
@@ -165,12 +164,8 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel):
class StoreSubmissionRequest(pydantic.BaseModel):
agent_id: str = pydantic.Field(
..., min_length=1, description="Agent ID cannot be empty"
)
agent_version: int = pydantic.Field(
..., gt=0, description="Agent version must be greater than 0"
)
agent_id: str
agent_version: int
slug: str
name: str
sub_heading: str

View File

@@ -138,7 +138,6 @@ def test_creator_details():
def test_store_submission():
submission = store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
@@ -160,7 +159,6 @@ def test_store_submissions_response():
response = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",

View File

@@ -521,7 +521,6 @@ def test_get_submissions_success(
mocked_value = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="test-listing-id",
name="Test Agent",
description="Test agent description",
image_urls=["test.jpg"],

View File

@@ -39,7 +39,7 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.blocks.llm import LlmModel
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -113,7 +113,7 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():

View File

@@ -1,7 +1,6 @@
from typing import Any
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -50,7 +49,7 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for evaluating the condition.",
advanced=False,
)
@@ -82,7 +81,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,

View File

@@ -6,9 +6,6 @@ import hashlib
import hmac
import logging
from enum import Enum
from typing import cast
from prisma.types import Serializable
from backend.sdk import (
BaseWebhooksManager,
@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
# update webhook config
await update_webhook(
webhook.id,
config=cast(
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
),
config={"base_id": base_id, "cursor": response.cursor},
)
event_type = "notification"

View File

@@ -1,602 +0,0 @@
import shlex
from typing import Literal
from e2b import AsyncSandbox as BaseAsyncSandbox
from pydantic import BaseModel, SecretStr
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
# Test credentials for E2B
TEST_E2B_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="e2b",
api_key=SecretStr("mock-e2b-api-key"),
title="Mock E2B API key",
expires_at=None,
)
TEST_E2B_CREDENTIALS_INPUT = {
"provider": TEST_E2B_CREDENTIALS.provider,
"id": TEST_E2B_CREDENTIALS.id,
"type": TEST_E2B_CREDENTIALS.type,
"title": TEST_E2B_CREDENTIALS.title,
}
# Test credentials for Anthropic
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
provider="anthropic",
api_key=SecretStr("mock-anthropic-api-key"),
title="Mock Anthropic API key",
expires_at=None,
)
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
"id": TEST_ANTHROPIC_CREDENTIALS.id,
"type": TEST_ANTHROPIC_CREDENTIALS.type,
"title": TEST_ANTHROPIC_CREDENTIALS.title,
}
class ClaudeCodeBlock(Block):
"""
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
Claude Code can create files, install tools, run commands, and perform complex
coding tasks autonomously within a secure sandbox environment.
"""
# Use base template - we'll install Claude Code ourselves for latest version
DEFAULT_TEMPLATE = "base"
class Input(BlockSchemaInput):
e2b_credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description=(
"API key for the E2B platform to create the sandbox. "
"Get one at https://e2b.dev/docs"
),
)
anthropic_credentials: CredentialsMetaInput[
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
] = CredentialsField(
description=(
"API key for Anthropic to power Claude Code. "
"Get one at https://console.anthropic.com/"
),
)
prompt: str = SchemaField(
description=(
"The task or instruction for Claude Code to execute. "
"Claude Code can create files, install packages, run commands, "
"and perform complex coding tasks."
),
placeholder="Create a hello world index.html file",
default="",
advanced=False,
)
timeout: int = SchemaField(
description=(
"Sandbox timeout in seconds. Claude Code tasks can take "
"a while, so set this appropriately for your task complexity. "
"Note: This only applies when creating a new sandbox. "
"When reconnecting to an existing sandbox via sandbox_id, "
"the original timeout is retained."
),
default=300, # 5 minutes default
advanced=True,
)
setup_commands: list[str] = SchemaField(
description=(
"Optional shell commands to run before executing Claude Code. "
"Useful for installing dependencies or setting up the environment."
),
default_factory=list,
advanced=True,
)
working_directory: str = SchemaField(
description="Working directory for Claude Code to operate in.",
default="/home/user",
advanced=True,
)
# Session/continuation support
session_id: str = SchemaField(
description=(
"Session ID to resume a previous conversation. "
"Leave empty for a new conversation. "
"Use the session_id from a previous run to continue that conversation."
),
default="",
advanced=True,
)
sandbox_id: str = SchemaField(
description=(
"Sandbox ID to reconnect to an existing sandbox. "
"Required when resuming a session (along with session_id). "
"Use the sandbox_id from a previous run where dispose_sandbox was False."
),
default="",
advanced=True,
)
conversation_history: str = SchemaField(
description=(
"Previous conversation history to continue from. "
"Use this to restore context on a fresh sandbox if the previous one timed out. "
"Pass the conversation_history output from a previous run."
),
default="",
advanced=True,
)
dispose_sandbox: bool = SchemaField(
description=(
"Whether to dispose of the sandbox immediately after execution. "
"Set to False if you want to continue the conversation later "
"(you'll need both sandbox_id and session_id from the output)."
),
default=True,
advanced=True,
)
class FileOutput(BaseModel):
"""A file extracted from the sandbox."""
path: str
relative_path: str # Path relative to working directory (for GitHub, etc.)
name: str
content: str
class Output(BlockSchemaOutput):
response: str = SchemaField(
description="The output/response from Claude Code execution"
)
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
description=(
"List of files created/modified by Claude Code. "
"Each file has 'path', 'name', and 'content' fields."
)
)
conversation_history: str = SchemaField(
description=(
"Full conversation history including this turn. "
"Pass this to conversation_history input to continue on a fresh sandbox "
"if the previous sandbox timed out."
)
)
session_id: str = SchemaField(
description=(
"Session ID for this conversation. "
"Pass this back along with sandbox_id to continue the conversation."
)
)
sandbox_id: str = SchemaField(
description=(
"ID of the sandbox instance. "
"Pass this back along with session_id to continue the conversation "
"(only available if dispose_sandbox was False)."
)
)
def __init__(self):
super().__init__(
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
description=(
"Execute tasks using Claude Code in an E2B sandbox. "
"Claude Code can create files, install tools, run commands, "
"and perform complex coding tasks autonomously."
),
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
input_schema=ClaudeCodeBlock.Input,
output_schema=ClaudeCodeBlock.Output,
test_credentials={
"e2b_credentials": TEST_E2B_CREDENTIALS,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
},
test_input={
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
"prompt": "Create a hello world HTML file",
"timeout": 300,
"setup_commands": [],
"working_directory": "/home/user",
"session_id": "",
"sandbox_id": "",
"conversation_history": "",
"dispose_sandbox": True,
},
test_output=[
("response", "Created index.html with hello world content"),
(
"files",
[
{
"path": "/home/user/index.html",
"relative_path": "index.html",
"name": "index.html",
"content": "<html>Hello World</html>",
}
],
),
(
"conversation_history",
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content",
),
("session_id", str),
],
test_mock={
"execute_claude_code": lambda *args, **kwargs: (
"Created index.html with hello world content", # response
[
ClaudeCodeBlock.FileOutput(
path="/home/user/index.html",
relative_path="index.html",
name="index.html",
content="<html>Hello World</html>",
)
], # files
"User: Create a hello world HTML file\n"
"Claude: Created index.html with hello world content", # conversation_history
"test-session-id", # session_id
"sandbox_id", # sandbox_id
),
},
)
async def execute_claude_code(
self,
e2b_api_key: str,
anthropic_api_key: str,
prompt: str,
timeout: int,
setup_commands: list[str],
working_directory: str,
session_id: str,
existing_sandbox_id: str,
conversation_history: str,
dispose_sandbox: bool,
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
"""
Execute Claude Code in an E2B sandbox.
Returns:
Tuple of (response, files, conversation_history, session_id, sandbox_id)
"""
import json
import uuid
# Validate that sandbox_id is provided when resuming a session
if session_id and not existing_sandbox_id:
raise ValueError(
"sandbox_id is required when resuming a session with session_id. "
"The session state is stored in the original sandbox. "
"If the sandbox has timed out, use conversation_history instead "
"to restore context on a fresh sandbox."
)
sandbox = None
try:
# Either reconnect to existing sandbox or create a new one
if existing_sandbox_id:
# Reconnect to existing sandbox for conversation continuation
sandbox = await BaseAsyncSandbox.connect(
sandbox_id=existing_sandbox_id,
api_key=e2b_api_key,
)
else:
# Create new sandbox
sandbox = await BaseAsyncSandbox.create(
template=self.DEFAULT_TEMPLATE,
api_key=e2b_api_key,
timeout=timeout,
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
)
# Install Claude Code from npm (ensures we get the latest version)
install_result = await sandbox.commands.run(
"npm install -g @anthropic-ai/claude-code@latest",
timeout=120, # 2 min timeout for install
)
if install_result.exit_code != 0:
raise Exception(
f"Failed to install Claude Code: {install_result.stderr}"
)
# Run any user-provided setup commands
for cmd in setup_commands:
setup_result = await sandbox.commands.run(cmd)
if setup_result.exit_code != 0:
raise Exception(
f"Setup command failed: {cmd}\n"
f"Exit code: {setup_result.exit_code}\n"
f"Stdout: {setup_result.stdout}\n"
f"Stderr: {setup_result.stderr}"
)
# Generate or use provided session ID
current_session_id = session_id if session_id else str(uuid.uuid4())
# Build base Claude flags
base_flags = "-p --dangerously-skip-permissions --output-format json"
# Add conversation history context if provided (for fresh sandbox continuation)
history_flag = ""
if conversation_history and not session_id:
# Inject previous conversation as context via system prompt
# Use consistent escaping via _escape_prompt helper
escaped_history = self._escape_prompt(
f"Previous conversation context: {conversation_history}"
)
history_flag = f" --append-system-prompt {escaped_history}"
# Build Claude command based on whether we're resuming or starting new
# Use shlex.quote for working_directory and session IDs to prevent injection
safe_working_dir = shlex.quote(working_directory)
if session_id:
# Resuming existing session (sandbox still alive)
safe_session_id = shlex.quote(session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --resume {safe_session_id} {base_flags}"
)
else:
# New session with specific ID
safe_current_session_id = shlex.quote(current_session_id)
claude_command = (
f"cd {safe_working_dir} && "
f"echo {self._escape_prompt(prompt)} | "
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
)
result = await sandbox.commands.run(
claude_command,
timeout=0, # No command timeout - let sandbox timeout handle it
)
# Check for command failure
if result.exit_code != 0:
error_msg = result.stderr or result.stdout or "Unknown error"
raise Exception(
f"Claude Code command failed with exit code {result.exit_code}:\n"
f"{error_msg}"
)
raw_output = result.stdout or ""
sandbox_id = sandbox.sandbox_id
# Parse JSON output to extract response and build conversation history
response = ""
new_conversation_history = conversation_history or ""
try:
# The JSON output contains the result
output_data = json.loads(raw_output)
response = output_data.get("result", raw_output)
# Build conversation history entry
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
except json.JSONDecodeError:
# If not valid JSON, use raw output
response = raw_output
turn_entry = f"User: {prompt}\nClaude: {response}"
if new_conversation_history:
new_conversation_history = (
f"{new_conversation_history}\n\n{turn_entry}"
)
else:
new_conversation_history = turn_entry
# Extract files from the working directory
files = await self._extract_files(sandbox, working_directory)
return (
response,
files,
new_conversation_history,
current_session_id,
sandbox_id,
)
finally:
if dispose_sandbox and sandbox:
await sandbox.kill()
async def _extract_files(
self,
sandbox: BaseAsyncSandbox,
working_directory: str,
) -> list["ClaudeCodeBlock.FileOutput"]:
"""
Extract files from the sandbox working directory.
Returns:
List of FileOutput objects with path, name, and content
"""
files: list[ClaudeCodeBlock.FileOutput] = []
# Text file extensions we can safely read
text_extensions = {
".txt",
".md",
".html",
".htm",
".css",
".js",
".ts",
".jsx",
".tsx",
".json",
".xml",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
".py",
".rb",
".php",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rs",
".swift",
".kt",
".scala",
".sh",
".bash",
".zsh",
".sql",
".graphql",
".env",
".gitignore",
".dockerfile",
"Dockerfile",
".vue",
".svelte",
".astro",
".mdx",
".rst",
".tex",
".csv",
".log",
}
try:
# List files recursively using find command
# Exclude node_modules and .git directories, but allow hidden files
# like .env and .gitignore (they're filtered by text_extensions later)
safe_working_dir = shlex.quote(working_directory)
find_result = await sandbox.commands.run(
f"find {safe_working_dir} -type f "
f"-not -path '*/node_modules/*' "
f"-not -path '*/.git/*' "
f"2>/dev/null | head -100"
)
if find_result.stdout:
for file_path in find_result.stdout.strip().split("\n"):
if not file_path:
continue
# Check if it's a text file we can read
is_text = any(
file_path.endswith(ext) for ext in text_extensions
) or file_path.endswith("Dockerfile")
if is_text:
try:
content = await sandbox.files.read(file_path)
# Handle bytes or string
if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")
# Extract filename from path
file_name = file_path.split("/")[-1]
# Calculate relative path by stripping working directory
relative_path = file_path
if file_path.startswith(working_directory):
relative_path = file_path[len(working_directory) :]
# Remove leading slash if present
if relative_path.startswith("/"):
relative_path = relative_path[1:]
files.append(
ClaudeCodeBlock.FileOutput(
path=file_path,
relative_path=relative_path,
name=file_name,
content=content,
)
)
except Exception:
# Skip files that can't be read
pass
except Exception:
# If file extraction fails, return empty results
pass
return files
def _escape_prompt(self, prompt: str) -> str:
"""Escape the prompt for safe shell execution."""
# Use single quotes and escape any single quotes in the prompt
escaped = prompt.replace("'", "'\"'\"'")
return f"'{escaped}'"
async def run(
self,
input_data: Input,
*,
e2b_credentials: APIKeyCredentials,
anthropic_credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
try:
(
response,
files,
conversation_history,
session_id,
sandbox_id,
) = await self.execute_claude_code(
e2b_api_key=e2b_credentials.api_key.get_secret_value(),
anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
prompt=input_data.prompt,
timeout=input_data.timeout,
setup_commands=input_data.setup_commands,
working_directory=input_data.working_directory,
session_id=input_data.session_id,
existing_sandbox_id=input_data.sandbox_id,
conversation_history=input_data.conversation_history,
dispose_sandbox=input_data.dispose_sandbox,
)
yield "response", response
# Always yield files (empty list if none) to match Output schema
yield "files", [f.model_dump() for f in files]
# Always yield conversation_history so user can restore context on fresh sandbox
yield "conversation_history", conversation_history
# Always yield session_id so user can continue conversation
yield "session_id", session_id
if not input_data.dispose_sandbox and sandbox_id:
yield "sandbox_id", sandbox_id
except Exception as e:
yield "error", str(e)

View File

@@ -182,10 +182,13 @@ class DataForSeoRelatedKeywordsBlock(Block):
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
# Handle missing key, null value, or valid list value
if isinstance(first_result, dict):
items = first_result.get("items") or []
else:
items = (
first_result.get("items", [])
if isinstance(first_result, dict)
else []
)
# Ensure items is never None
if items is None:
items = []
for item in items:
# Extract keyword_data from the item

File diff suppressed because it is too large.

View File

@@ -1,184 +0,0 @@
"""
Shared helpers for Human-In-The-Loop (HITL) review functionality.
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
"""
import logging
from typing import Any, Optional
from prisma.enums import ReviewStatus
from pydantic import BaseModel
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
class ReviewDecision(BaseModel):
"""Result of a review decision."""
should_proceed: bool
message: str
review_result: ReviewResult
class HITLReviewHelper:
"""Helper class for Human-In-The-Loop review operations."""
@staticmethod
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
"""Create or retrieve a human review from the database."""
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
@staticmethod
async def update_node_execution_status(**kwargs) -> None:
"""Update the execution status of a node."""
await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
@staticmethod
async def update_review_processed_status(
node_exec_id: str, processed: bool
) -> None:
"""Update the processed status of a review."""
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
@staticmethod
async def _handle_review_request(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewResult if review is complete, None if waiting for human input
Raises:
Exception: If review creation or status update fails
"""
# Skip review if safe mode is disabled - return auto-approved result
if not execution_context.safe_mode:
logger.info(
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
)
return ReviewResult(
data=input_data,
status=ReviewStatus.APPROVED,
message="Auto-approved (safe mode disabled)",
processed=True,
node_exec_id=node_exec_id,
)
result = await HITLReviewHelper.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data,
message=f"Review required for {block_name} execution",
editable=editable,
)
if result is None:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
if not result.processed:
await HITLReviewHelper.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
return result
@staticmethod
async def handle_review_decision(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewDecision if review is complete (approved/rejected),
None if execution should pause (awaiting review)
"""
review_result = await HITLReviewHelper._handle_review_request(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=block_name,
editable=editable,
)
if review_result is None:
# Still awaiting review - return None to pause execution
return None
# Review is complete, determine outcome
should_proceed = review_result.status == ReviewStatus.APPROVED
message = review_result.message or (
"Execution approved by reviewer"
if should_proceed
else "Execution rejected by reviewer"
)
return ReviewDecision(
should_proceed=should_proceed, message=message, review_result=review_result
)

View File

@@ -3,7 +3,6 @@ from typing import Any
from prisma.enums import ReviewStatus
from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.block import (
Block,
BlockCategory,
@@ -12,9 +11,11 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
@@ -71,26 +72,32 @@ class HumanInTheLoopBlock(Block):
("approved_data", {"name": "John Doe", "age": 30}),
],
test_mock={
"handle_review_decision": lambda **kwargs: type(
"ReviewDecision",
(),
{
"should_proceed": True,
"message": "Test approval message",
"review_result": ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
},
)(),
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
"update_node_execution_status": lambda *_args, **_kwargs: None,
"update_review_processed_status": lambda *_args, **_kwargs: None,
},
)
async def handle_review_decision(self, **kwargs):
return await HITLReviewHelper.handle_review_decision(**kwargs)
async def get_or_create_human_review(self, **kwargs):
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
async def update_node_execution_status(self, **kwargs):
return await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
async def update_review_processed_status(self, node_exec_id: str, processed: bool):
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
async def run(
self,
@@ -102,7 +109,7 @@ class HumanInTheLoopBlock(Block):
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
**_kwargs,
**kwargs,
) -> BlockOutput:
if not execution_context.safe_mode:
logger.info(
@@ -112,28 +119,48 @@ class HumanInTheLoopBlock(Block):
yield "review_message", "Auto-approved (safe mode disabled)"
return
decision = await self.handle_review_decision(
input_data=input_data.data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=input_data.editable,
)
try:
result = await self.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data.data,
message=input_data.name,
editable=input_data.editable,
)
except Exception as e:
logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
raise
if decision is None:
return
if result is None:
logger.info(
f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
)
try:
await self.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return
except Exception as e:
logger.error(
f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
)
raise
status = decision.review_result.status
if status == ReviewStatus.APPROVED:
yield "approved_data", decision.review_result.data
elif status == ReviewStatus.REJECTED:
yield "rejected_data", decision.review_result.data
else:
raise RuntimeError(f"Unexpected review status: {status}")
if not result.processed:
await self.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
if decision.message:
yield "review_message", decision.message
if result.status == ReviewStatus.APPROVED:
yield "approved_data", result.data
if result.message:
yield "review_message", result.message
elif result.status == ReviewStatus.REJECTED:
yield "rejected_data", result.data
if result.message:
yield "review_message", result.message

View File

@@ -92,9 +92,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
@@ -195,9 +194,8 @@ MODEL_METADATA = {
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
@@ -305,8 +303,6 @@ MODEL_METADATA = {
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
@@ -794,7 +790,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -859,7 +855,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1225,7 +1221,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -1321,7 +1317,7 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for summarizing the text.",
)
focus: str = SchemaField(
@@ -1538,7 +1534,7 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for the conversation.",
)
credentials: AICredentials = AICredentialsField()
@@ -1576,7 +1572,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1639,7 +1635,7 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=DEFAULT_LLM_MODEL,
default=LlmModel.GPT4O,
description="The language model to use for generating the list.",
advanced=True,
)
@@ -1696,7 +1692,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": DEFAULT_LLM_MODEL,
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

File diff suppressed because it is too large.

View File

@@ -18,7 +18,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import DEFAULT_USER_AGENT
class GetWikipediaSummaryBlock(Block, GetRequest):
@@ -40,27 +39,17 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
output_schema=GetWikipediaSummaryBlock.Output,
test_input={"topic": "Artificial Intelligence"},
test_output=("summary", "summary content"),
test_mock={
"get_request": lambda url, headers, json: {"extract": "summary content"}
},
test_mock={"get_request": lambda url, json: {"extract": "summary content"}},
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
topic = input_data.topic
# URL-encode the topic to handle spaces and special characters
encoded_topic = quote(topic, safe="")
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{encoded_topic}"
# Set headers per Wikimedia robot policy (https://w.wiki/4wJS)
# - User-Agent: Required, must identify the bot
# - Accept-Encoding: gzip recommended to reduce bandwidth
headers = {
"User-Agent": DEFAULT_USER_AGENT,
"Accept-Encoding": "gzip, deflate",
}
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
# Note: User-Agent is now automatically set by the request library
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
try:
response = await self.get_request(url, headers=headers, json=True)
response = await self.get_request(url, json=True)
if "extract" not in response:
raise ValueError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]

View File

@@ -226,7 +226,7 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.DEFAULT_LLM_MODEL,
default=llm.LlmModel.GPT4O,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -391,12 +391,8 @@ class SmartDecisionMakerBlock(Block):
"""
block = sink_node.block
# Use custom name from node metadata if set, otherwise fall back to block.name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else block.name
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"name": SmartDecisionMakerBlock.cleanup(block.name),
"description": block.description,
}
sink_block_input_schema = block.input_schema
@@ -493,24 +489,14 @@ class SmartDecisionMakerBlock(Block):
f"Sink graph metadata not found: {graph_id} {graph_version}"
)
# Use custom name from node metadata if set, otherwise fall back to graph name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else sink_graph_meta.name
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
"description": sink_graph_meta.description,
}
properties = {}
field_mapping = {}
for link in links:
field_name = link.sink_name
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
field_mapping[clean_field_name] = field_name
sink_block_input_schema = sink_node.input_default["input_schema"]
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
link.sink_name, {}
@@ -520,7 +506,7 @@ class SmartDecisionMakerBlock(Block):
if "description" in sink_block_properties
else f"The {link.sink_name} of the tool"
)
properties[clean_field_name] = {
properties[link.sink_name] = {
"type": "string",
"description": description,
"default": json.dumps(sink_block_properties.get("default", None)),
@@ -533,7 +519,7 @@ class SmartDecisionMakerBlock(Block):
"strict": True,
}
tool_function["_field_mapping"] = field_mapping
# Store node info for later use in output processing
tool_function["_sink_node_id"] = sink_node.id
return {"type": "function", "function": tool_function}
@@ -989,28 +975,10 @@ class SmartDecisionMakerBlock(Block):
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
nodes_to_skip: set[str] | None = None,
**kwargs,
) -> BlockOutput:
tool_functions = await self._create_tool_node_signatures(node_id)
original_tool_count = len(tool_functions)
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
if nodes_to_skip:
tool_functions = [
tf
for tf in tool_functions
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
]
# Only raise error if we had tools but they were all filtered out
if original_tool_count > 0 and not tool_functions:
raise ValueError(
"No available tools to execute - all downstream nodes are unavailable "
"(possibly due to missing optional credentials)"
)
yield "tool_functions", json.dumps(tool_functions)
conversation_history = input_data.conversation_history or []
@@ -1161,9 +1129,8 @@ class SmartDecisionMakerBlock(Block):
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
arg_value = tool_args.get(clean_arg_name)
# Use original_field_name directly (not sanitized) to match link sink_name
# The field_mapping already translates from LLM's cleaned names to original names
emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"
sanitized_arg_name = self.cleanup(original_field_name)
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
logger.debug(
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",

View File
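The hunks above change how tool and argument names round-trip between the graph and the LLM: names shown to the model are cleaned, and a _field_mapping translates the model's cleaned argument names back to the original link sink names before the tools_^_{sink_node_id}_~_{field_name} outputs are emitted. Below is a minimal sketch of that mapping, assuming cleanup() simply lowercases and replaces non-alphanumeric characters with underscores (an assumption inferred from the test expectations; the real implementation is not shown in this diff).

import re

# Hedged sketch of the name cleanup + field mapping described above.
def cleanup(name: str) -> str:
    return re.sub(r"[^a-z0-9_]+", "_", name.lower()).strip("_")

field_mapping = {cleanup(sink): sink for sink in ["max_keyword_difficulty", "Search Query"]}

sink_node_id = "test-sink-node-id"       # illustrative id
tool_args = {"search_query": "agents"}   # argument names as returned by the LLM
for clean_name, value in tool_args.items():
    original = field_mapping.get(clean_name, clean_name)
    emit_key = f"tools_^_{sink_node_id}_~_{original}"
    print(emit_key, "=", value)          # tools_^_test-sink-node-id_~_Search Query = agents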

@@ -196,15 +196,6 @@ class TestXMLParserBlockSecurity:
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
pass
async def test_rejects_text_outside_root(self):
"""Ensure parser surfaces readable errors for invalid root text."""
block = XMLParserBlock()
invalid_xml = "<root><child>value</child></root> trailing"
with pytest.raises(ValueError, match="text outside the root element"):
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""

View File

@@ -28,7 +28,7 @@ class TestLLMStatsTracking:
response = await llm.llm_call(
credentials=llm.TEST_CREDENTIALS,
llm_model=llm.DEFAULT_LLM_MODEL,
llm_model=llm.LlmModel.GPT4O,
prompt=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AITextGeneratorBlock.Input(
prompt="Generate text",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:
input_data = llm.AITextSummarizerBlock.Input(
text=long_text,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=100, # Small chunks
chunk_overlap=10,
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
# Test with very short text (should only need 1 chunk + 1 final summary)
input_data = llm.AITextSummarizerBlock.Input(
text="This is a short text.",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000, # Large enough to avoid chunking
)
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
],
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AIListGeneratorBlock.Input(
focus="test items",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_retries=3,
)
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test",
expected_format={"result": "desc"},
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
)
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
max_tokens=1000,
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000,
)
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)

View File

@@ -233,7 +233,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
# Create test input
input_data = SmartDecisionMakerBlock.Input(
prompt="Should I continue with this task?",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -335,7 +335,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, # Set retry to 2 for testing
agent_mode_max_iterations=0,
@@ -402,7 +402,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -462,7 +462,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -526,7 +526,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -648,7 +648,7 @@ async def test_smart_decision_maker_raw_response_conversion():
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
agent_mode_max_iterations=0,
@@ -722,7 +722,7 @@ async def test_smart_decision_maker_raw_response_conversion():
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Simple prompt",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -778,7 +778,7 @@ async def test_smart_decision_maker_raw_response_conversion():
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Another test",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -931,7 +931,7 @@ async def test_smart_decision_maker_agent_mode():
# Test agent mode with max_iterations = 3
input_data = SmartDecisionMakerBlock.Input(
prompt="Complete this task using tools",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
)
@@ -1020,7 +1020,7 @@ async def test_smart_decision_maker_traditional_mode_default():
# Test default behavior (traditional mode)
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.DEFAULT_LLM_MODEL,
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0, # Traditional mode
)
@@ -1057,153 +1057,3 @@ async def test_smart_decision_maker_traditional_mode_default():
) # Should yield individual tool parameters
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert "conversations" in outputs
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_blocks():
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
mock_node = MagicMock(spec=Node)
mock_node.id = "test-node-id"
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the customized name (cleaned up)
assert result["type"] == "function"
assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version
assert result["function"]["_sink_node_id"] == "test-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_falls_back_to_block_name():
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
mock_node = MagicMock(spec=Node)
mock_node.id = "test-node-id"
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the block's default name
assert result["type"] == "function"
assert result["function"]["name"] == "storevalueblock" # Default block name cleaned
assert result["function"]["_sink_node_id"] == "test-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_agents():
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
mock_node = MagicMock(spec=Node)
mock_node.id = "test-agent-node-id"
mock_node.metadata = {"customized_name": "My Custom Agent"}
mock_node.input_default = {
"graph_id": "test-graph-id",
"graph_version": 1,
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "test_input"
# Mock the database client
mock_graph_meta = MagicMock()
mock_graph_meta.name = "Original Agent Name"
mock_graph_meta.description = "Agent description"
mock_db_client = AsyncMock()
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the customized name (cleaned up)
assert result["type"] == "function"
assert result["function"]["name"] == "my_custom_agent" # Cleaned version
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
"""Test that agent node falls back to graph name when no customized_name."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
mock_node = MagicMock(spec=Node)
mock_node.id = "test-agent-node-id"
mock_node.metadata = {} # No customized_name
mock_node.input_default = {
"graph_id": "test-graph-id",
"graph_version": 1,
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "test_input"
# Mock the database client
mock_graph_meta = MagicMock()
mock_graph_meta.name = "Original Agent Name"
mock_graph_meta.description = "Agent description"
mock_db_client = AsyncMock()
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the graph's default name
assert result["type"] == "function"
assert result["function"]["name"] == "original_agent_name" # Graph name cleaned
assert result["function"]["_sink_node_id"] == "test-agent-node-id"

View File

@@ -15,7 +15,6 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
mock_node.block = CreateDictionaryBlock()
mock_node.block_id = CreateDictionaryBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic dictionary fields
mock_links = [
@@ -78,7 +77,6 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
mock_node.block = AddToListBlock()
mock_node.block_id = AddToListBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic list fields
mock_links = [

View File

@@ -44,7 +44,6 @@ async def test_create_block_function_signature_with_dict_fields():
mock_node.block = CreateDictionaryBlock()
mock_node.block_id = CreateDictionaryBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
mock_links = [
@@ -107,7 +106,6 @@ async def test_create_block_function_signature_with_list_fields():
mock_node.block = AddToListBlock()
mock_node.block_id = AddToListBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic list fields
mock_links = [
@@ -161,7 +159,6 @@ async def test_create_block_function_signature_with_object_fields():
mock_node.block = MatchTextPatternBlock()
mock_node.block_id = MatchTextPatternBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic object fields
mock_links = [
@@ -211,13 +208,11 @@ async def test_create_tool_node_signatures():
mock_dict_node.block = CreateDictionaryBlock()
mock_dict_node.block_id = CreateDictionaryBlock().id
mock_dict_node.input_default = {}
mock_dict_node.metadata = {}
mock_list_node = Mock()
mock_list_node.block = AddToListBlock()
mock_list_node.block_id = AddToListBlock().id
mock_list_node.input_default = {}
mock_list_node.metadata = {}
# Mock links with dynamic fields
dict_link1 = Mock(
@@ -378,7 +373,7 @@ async def test_output_yielding_with_dynamic_fields():
input_data = block.input_schema(
prompt="Create a user dictionary",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
)
@@ -428,7 +423,6 @@ async def test_mixed_regular_and_dynamic_fields():
mock_node.block.name = "TestBlock"
mock_node.block.description = "A test block"
mock_node.block.input_schema = Mock()
mock_node.metadata = {}
# Mock the get_field_schema to return a proper schema for regular fields
def get_field_schema(field_name):
@@ -600,7 +594,7 @@ async def test_validation_errors_dont_pollute_conversation():
input_data = block.input_schema(
prompt="Test prompt",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.DEFAULT_LLM_MODEL,
model=llm.LlmModel.GPT4O,
retry=3, # Allow retries
agent_mode_max_iterations=1,
)

View File

@@ -1,3 +1,3 @@
from .blog import WordPressCreatePostBlock, WordPressGetAllPostsBlock
from .blog import WordPressCreatePostBlock
__all__ = ["WordPressCreatePostBlock", "WordPressGetAllPostsBlock"]
__all__ = ["WordPressCreatePostBlock"]

View File

@@ -161,7 +161,7 @@ async def oauth_exchange_code_for_tokens(
grant_type="authorization_code",
).model_dump(exclude_none=True)
response = await Requests(raise_for_status=False).post(
response = await Requests().post(
f"{WORDPRESS_BASE_URL}oauth2/token",
headers=headers,
data=data,
@@ -205,7 +205,7 @@ async def oauth_refresh_tokens(
grant_type="refresh_token",
).model_dump(exclude_none=True)
response = await Requests(raise_for_status=False).post(
response = await Requests().post(
f"{WORDPRESS_BASE_URL}oauth2/token",
headers=headers,
data=data,
@@ -252,7 +252,7 @@ async def validate_token(
"token": token,
}
response = await Requests(raise_for_status=False).get(
response = await Requests().get(
f"{WORDPRESS_BASE_URL}oauth2/token-info",
params=params,
)
@@ -296,7 +296,7 @@ async def make_api_request(
url = f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}"
request_method = getattr(Requests(raise_for_status=False), method.lower())
request_method = getattr(Requests(), method.lower())
response = await request_method(
url,
headers=headers,
@@ -476,7 +476,6 @@ async def create_post(
data["tags"] = ",".join(str(t) for t in data["tags"])
# Make the API request
site = normalize_site(site)
endpoint = f"/rest/v1.1/sites/{site}/posts/new"
headers = {
@@ -484,7 +483,7 @@ async def create_post(
"Content-Type": "application/x-www-form-urlencoded",
}
response = await Requests(raise_for_status=False).post(
response = await Requests().post(
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
headers=headers,
data=data,
@@ -500,132 +499,3 @@ async def create_post(
)
error_message = error_data.get("message", response.text)
raise ValueError(f"Failed to create post: {response.status} - {error_message}")
class Post(BaseModel):
"""Response model for individual posts in a posts list response.
This is a simplified version compared to PostResponse, as the list endpoint
returns less detailed information than the create/get single post endpoints.
"""
ID: int
site_ID: int
author: PostAuthor
date: datetime
modified: datetime
title: str
URL: str
short_URL: str
content: str | None = None
excerpt: str | None = None
slug: str
guid: str
status: str
sticky: bool
password: str | None = ""
parent: Union[Dict[str, Any], bool, None] = None
type: str
discussion: Dict[str, Union[str, bool, int]] | None = None
likes_enabled: bool | None = None
sharing_enabled: bool | None = None
like_count: int | None = None
i_like: bool | None = None
is_reblogged: bool | None = None
is_following: bool | None = None
global_ID: str | None = None
featured_image: str | None = None
post_thumbnail: Dict[str, Any] | None = None
format: str | None = None
geo: Union[Dict[str, Any], bool, None] = None
menu_order: int | None = None
page_template: str | None = None
publicize_URLs: List[str] | None = None
terms: Dict[str, Dict[str, Any]] | None = None
tags: Dict[str, Dict[str, Any]] | None = None
categories: Dict[str, Dict[str, Any]] | None = None
attachments: Dict[str, Dict[str, Any]] | None = None
attachment_count: int | None = None
metadata: List[Dict[str, Any]] | None = None
meta: Dict[str, Any] | None = None
capabilities: Dict[str, bool] | None = None
revisions: List[int] | None = None
other_URLs: Dict[str, Any] | None = None
class PostsResponse(BaseModel):
"""Response model for WordPress posts list."""
found: int
posts: List[Post]
meta: Dict[str, Any]
def normalize_site(site: str) -> str:
"""
Normalize a site identifier by stripping protocol and trailing slashes.
Args:
site: Site URL, domain, or ID (e.g., "https://myblog.wordpress.com/", "myblog.wordpress.com", "123456789")
Returns:
Normalized site identifier (domain or ID only)
"""
site = site.strip()
if site.startswith("https://"):
site = site[8:]
elif site.startswith("http://"):
site = site[7:]
return site.rstrip("/")
async def get_posts(
credentials: Credentials,
site: str,
status: PostStatus | None = None,
number: int = 100,
offset: int = 0,
) -> PostsResponse:
"""
Get posts from a WordPress site.
Args:
credentials: OAuth credentials
site: Site ID or domain (e.g., "myblog.wordpress.com" or "123456789")
status: Filter by post status using PostStatus enum, or None for all
number: Number of posts to retrieve (max 100)
offset: Number of posts to skip (for pagination)
Returns:
PostsResponse with the list of posts
"""
site = normalize_site(site)
endpoint = f"/rest/v1.1/sites/{site}/posts"
headers = {
"Authorization": credentials.auth_header(),
}
params: Dict[str, Any] = {
"number": max(1, min(number, 100)), # 1100 posts per request
"offset": offset,
}
if status:
params["status"] = status.value
response = await Requests(raise_for_status=False).get(
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
headers=headers,
params=params,
)
if response.ok:
return PostsResponse.model_validate(response.json())
error_data = (
response.json()
if response.headers.get("content-type", "").startswith("application/json")
else {}
)
error_message = error_data.get("message", response.text)
raise ValueError(f"Failed to get posts: {response.status} - {error_message}")

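get_posts above supports status filtering and offset-based pagination, capped at 100 posts per request. Here is a hedged usage sketch that pages through every published post using the helpers shown in this file (it assumes get_posts, Post, and PostStatus are importable from ._api and that OAuth credentials are already in hand):

async def fetch_all_published(credentials, site: str) -> list[Post]:
    # Illustrative pagination loop over get_posts(); not part of the diff.
    collected: list[Post] = []
    offset = 0
    while True:
        page = await get_posts(
            credentials=credentials,
            site=site,
            status=PostStatus.PUBLISH,  # enum member used elsewhere in this file
            number=100,                 # API maximum per request
            offset=offset,
        )
        collected.extend(page.posts)
        offset += len(page.posts)
        if not page.posts or offset >= page.found:
            break
    return collected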
View File

@@ -9,15 +9,7 @@ from backend.sdk import (
SchemaField,
)
from ._api import (
CreatePostRequest,
Post,
PostResponse,
PostsResponse,
PostStatus,
create_post,
get_posts,
)
from ._api import CreatePostRequest, PostResponse, PostStatus, create_post
from ._config import wordpress
@@ -57,15 +49,8 @@ class WordPressCreatePostBlock(Block):
media_urls: list[str] = SchemaField(
description="URLs of images to sideload and attach to the post", default=[]
)
publish_as_draft: bool = SchemaField(
description="If True, publishes the post as a draft. If False, publishes it publicly.",
default=False,
)
class Output(BlockSchemaOutput):
site: str = SchemaField(
description="The site ID or domain (pass-through for chaining with other blocks)"
)
post_id: int = SchemaField(description="The ID of the created post")
post_url: str = SchemaField(description="The full URL of the created post")
short_url: str = SchemaField(description="The shortened wp.me URL")
@@ -93,9 +78,7 @@ class WordPressCreatePostBlock(Block):
tags=input_data.tags,
featured_image=input_data.featured_image,
media_urls=input_data.media_urls,
status=(
PostStatus.DRAFT if input_data.publish_as_draft else PostStatus.PUBLISH
),
status=PostStatus.PUBLISH,
)
post_response: PostResponse = await create_post(
@@ -104,69 +87,7 @@ class WordPressCreatePostBlock(Block):
post_data=post_request,
)
yield "site", input_data.site
yield "post_id", post_response.ID
yield "post_url", post_response.URL
yield "short_url", post_response.short_URL
yield "post_data", post_response.model_dump()
class WordPressGetAllPostsBlock(Block):
"""
Fetches all posts from a WordPress.com site or Jetpack-enabled site.
Supports filtering by status and pagination.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = wordpress.credentials_field()
site: str = SchemaField(
description="Site ID or domain (e.g., 'myblog.wordpress.com' or '123456789')"
)
status: PostStatus | None = SchemaField(
description="Filter by post status, or None for all",
default=None,
)
number: int = SchemaField(
description="Number of posts to retrieve (max 100 per request)", default=20
)
offset: int = SchemaField(
description="Number of posts to skip (for pagination)", default=0
)
class Output(BlockSchemaOutput):
site: str = SchemaField(
description="The site ID or domain (pass-through for chaining with other blocks)"
)
found: int = SchemaField(description="Total number of posts found")
posts: list[Post] = SchemaField(
description="List of post objects with their details"
)
post: Post = SchemaField(
description="Individual post object (yielded for each post)"
)
def __init__(self):
super().__init__(
id="97728fa7-7f6f-4789-ba0c-f2c114119536",
description="Fetch all posts from WordPress.com or Jetpack sites",
categories={BlockCategory.SOCIAL},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: Credentials, **kwargs
) -> BlockOutput:
posts_response: PostsResponse = await get_posts(
credentials=credentials,
site=input_data.site,
status=input_data.status,
number=input_data.number,
offset=input_data.offset,
)
yield "site", input_data.site
yield "found", posts_response.found
yield "posts", posts_response.posts
for post in posts_response.posts:
yield "post", post

View File

@@ -1,5 +1,5 @@
from gravitasml.parser import Parser
from gravitasml.token import Token, tokenize
from gravitasml.token import tokenize
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import SchemaField
@@ -25,38 +25,6 @@ class XMLParserBlock(Block):
],
)
@staticmethod
def _validate_tokens(tokens: list[Token]) -> None:
"""Ensure the XML has a single root element and no stray text."""
if not tokens:
raise ValueError("XML input is empty.")
depth = 0
root_seen = False
for token in tokens:
if token.type == "TAG_OPEN":
if depth == 0 and root_seen:
raise ValueError("XML must have a single root element.")
depth += 1
if depth == 1:
root_seen = True
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise SyntaxError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
"XML contains text outside the root element; "
"wrap content in a single root tag."
)
if depth != 0:
raise SyntaxError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Security fix: Add size limits to prevent XML bomb attacks
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
@@ -67,12 +35,16 @@ class XMLParserBlock(Block):
)
try:
tokens = list(tokenize(input_data.input_xml))
self._validate_tokens(tokens)
tokens = tokenize(input_data.input_xml)
parser = Parser(tokens)
parsed_result = parser.parse()
yield "parsed_xml", parsed_result
except AttributeError as attr_e:
raise ValueError(
f"XML parsing error: The gravitasml library encountered an incompatible structure. "
f"This may occur with certain XML formats that create List objects with text content. "
f"Details: {attr_e}"
) from attr_e
except ValueError as val_e:
raise ValueError(f"Validation error for dict:{val_e}") from val_e
except SyntaxError as syn_e:

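Taken together, the token validation and the AttributeError/ValueError handling above mean callers see a readable ValueError for structurally invalid XML rather than a parser crash. A hedged pytest-style sketch of that contract follows (the import path for XMLParserBlock is assumed and left commented; the valid-input assertion only checks which output name is yielded):

import pytest

# from backend.blocks.xml_parser import XMLParserBlock  # import path assumed; adjust as needed

@pytest.mark.asyncio
async def test_single_root_contract():
    block = XMLParserBlock()

    # Exactly one root element: parses and yields "parsed_xml".
    async for name, _value in block.run(
        XMLParserBlock.Input(input_xml="<root><child>value</child></root>")
    ):
        assert name == "parsed_xml"

    # Text trailing after the root element surfaces as a ValueError.
    with pytest.raises(ValueError):
        async for _ in block.run(
            XMLParserBlock.Input(
                input_xml="<root><child>value</child></root> trailing"
            )
        ):
            pass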
View File

@@ -111,8 +111,6 @@ class TranscribeYoutubeVideoBlock(Block):
return parsed_url.path.split("/")[2]
if parsed_url.path[:3] == "/v/":
return parsed_url.path.split("/")[2]
if parsed_url.path.startswith("/shorts/"):
return parsed_url.path.split("/")[2]
raise ValueError(f"Invalid YouTube URL: {url}")
def get_transcript(

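The hunk above is the path-based branch of video-ID extraction in TranscribeYoutubeVideoBlock, covering /v/ and /shorts/ URLs. A tiny hedged sketch of just that branch using urllib.parse; the real method presumably also handles other URL shapes (e.g. standard watch URLs), which are outside this hunk:

from urllib.parse import urlparse

def extract_path_video_id(url: str) -> str:
    # Sketch of the /v/ and /shorts/ branches shown above; other URL forms
    # handled by the real block are intentionally omitted here.
    path = urlparse(url).path
    if path.startswith("/v/") or path.startswith("/shorts/"):
        return path.split("/")[2]
    raise ValueError(f"Invalid YouTube URL: {url}")

assert extract_path_video_id("https://www.youtube.com/shorts/abc123") == "abc123"
assert extract_path_video_id("https://www.youtube.com/v/abc123") == "abc123"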
View File

@@ -50,8 +50,6 @@ from .model import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from .graph import Link
app_config = Config()
@@ -474,7 +472,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.requires_human_review: bool = False
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -617,77 +614,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
) from ex
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
# Skip review if not required or safe mode is disabled
if not self.requires_human_review or not execution_context.safe_mode:
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement and get potentially modified input data
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
@@ -695,7 +622,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,

View File

@@ -59,13 +59,12 @@ from backend.integrations.credentials_store import (
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_1: 5,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
@@ -88,7 +87,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,

View File

@@ -341,19 +341,6 @@ class UserCreditBase(ABC):
if result:
# UserBalance is already updated by the CTE
# Clear insufficient funds notification flags when credits are added
# so user can receive alerts again if they run out in the future.
if transaction.amount > 0 and transaction.type in [
CreditTransactionType.GRANT,
CreditTransactionType.TOP_UP,
]:
from backend.executor.manager import (
clear_insufficient_funds_notifications,
)
await clear_insufficient_funds_notifications(user_id)
return result[0]["balance"]
async def _add_transaction(
@@ -543,22 +530,6 @@ class UserCreditBase(ABC):
if result:
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
# UserBalance is already updated by the CTE
# Clear insufficient funds notification flags when credits are added
# so user can receive alerts again if they run out in the future.
if (
amount > 0
and is_active
and transaction_type
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
):
# Lazy import to avoid circular dependency with executor.manager
from backend.executor.manager import (
clear_insufficient_funds_notifications,
)
await clear_insufficient_funds_notifications(user_id)
return new_balance, tx_key
# If no result, either user doesn't exist or insufficient balance

View File

@@ -383,7 +383,6 @@ class GraphExecutionWithNodes(GraphExecution):
self,
execution_context: ExecutionContext,
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
):
return GraphExecutionEntry(
user_id=self.user_id,
@@ -391,7 +390,6 @@ class GraphExecutionWithNodes(GraphExecution):
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip or set(),
execution_context=execution_context,
)
@@ -1147,8 +1145,6 @@ class GraphExecutionEntry(BaseModel):
graph_id: str
graph_version: int
nodes_input_masks: Optional[NodesInputMasks] = None
nodes_to_skip: set[str] = Field(default_factory=set)
"""Node IDs that should be skipped due to optional credentials not being configured."""
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)

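GraphExecutionEntry above carries nodes_to_skip so the executor can pass over nodes whose optional credentials were never configured; as the manager hunk further down shows, such nodes are marked COMPLETED without running, so no outputs are produced and downstream links never fire. A hedged, stand-alone sketch of that dispatch-time check (the queue item and status helper are stand-ins for the real execution types):

from dataclasses import dataclass

@dataclass
class QueuedNodeExec:          # stand-in for the real queued execution entry
    node_id: str
    node_exec_id: str

def mark_completed(node_exec_id: str) -> None:
    # Stand-in for update_node_execution_status(..., status=COMPLETED).
    print(f"skipped {node_exec_id}: optional credentials not configured")

def dispatch(queue: list[QueuedNodeExec], nodes_to_skip: set[str]) -> list[str]:
    dispatched: list[str] = []
    for item in queue:
        if item.node_id in nodes_to_skip:
            mark_completed(item.node_exec_id)  # completed without executing
            continue
        dispatched.append(item.node_exec_id)
    return dispatched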
View File

@@ -94,15 +94,6 @@ class Node(BaseDbModel):
input_links: list[Link] = []
output_links: list[Link] = []
@property
def credentials_optional(self) -> bool:
"""
Whether credentials are optional for this node.
When True and credentials are not configured, the node will be skipped
during execution rather than causing a validation error.
"""
return self.metadata.get("credentials_optional", False)
@property
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
@@ -244,10 +235,7 @@ class BaseGraph(BaseDbModel):
return any(
node.block_id
for node in self.nodes
if (
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
or node.block.requires_human_review
)
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
)
@property
@@ -338,35 +326,7 @@ class Graph(BaseGraph):
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
schema = self._credentials_input_schema.jsonschema()
# Determine which credential fields are required based on credentials_optional metadata
graph_credentials_inputs = self.aggregate_credentials_inputs()
required_fields = []
# Build a map of node_id -> node for quick lookup
all_nodes = {node.id: node for node in self.nodes}
for sub_graph in self.sub_graphs:
for node in sub_graph.nodes:
all_nodes[node.id] = node
for field_key, (
_field_info,
node_field_pairs,
) in graph_credentials_inputs.items():
# A field is required if ANY node using it has credentials_optional=False
is_required = False
for node_id, _field_name in node_field_pairs:
node = all_nodes.get(node_id)
if node and not node.credentials_optional:
is_required = True
break
if is_required:
required_fields.append(field_key)
schema["required"] = required_fields
return schema
return self._credentials_input_schema.jsonschema()
@property
def _credentials_input_schema(self) -> type[BlockSchema]:

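The credentials_input_schema hunk above decides which aggregated credential fields stay required: a field is required if any node mapped to it does not mark its credentials as optional. A compact hedged sketch of that rule over plain dicts (field and node names are illustrative):

# Which nodes consider their credentials optional (illustrative data).
node_optional = {"n1": True, "n2": False, "n3": True}

# field_key -> list of (node_id, field_name) pairs, as produced by
# aggregate_credentials_inputs() in the real method.
credential_fields = {
    "openai_api_key": [("n1", "credentials"), ("n2", "credentials")],
    "wordpress_oauth": [("n3", "credentials")],
}

required = [
    field
    for field, pairs in credential_fields.items()
    if any(not node_optional[node_id] for node_id, _ in pairs)
]
assert required == ["openai_api_key"]  # n2 still requires its credentials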
View File

@@ -396,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
)
assert got_graph is not None
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
def test_node_credentials_optional_default():
"""Test that credentials_optional defaults to False when not set in metadata."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={},
)
assert node.credentials_optional is False
def test_node_credentials_optional_true():
"""Test that credentials_optional returns True when explicitly set."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": True},
)
assert node.credentials_optional is True
def test_node_credentials_optional_false():
"""Test that credentials_optional returns False when explicitly set to False."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": False},
)
assert node.credentials_optional is False
def test_node_credentials_optional_with_other_metadata():
"""Test that credentials_optional works correctly with other metadata present."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={
"position": {"x": 100, "y": 200},
"customized_name": "My Custom Node",
"credentials_optional": True,
},
)
assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node"

View File

@@ -114,40 +114,6 @@ utilization_gauge = Gauge(
"Ratio of active graph runs to max graph workers",
)
# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
async def clear_insufficient_funds_notifications(user_id: str) -> int:
"""
Clear all insufficient funds notification flags for a user.
This should be called when a user tops up their credits, allowing
Discord notifications to be sent again if they run out of funds.
Args:
user_id: The user ID to clear notifications for.
Returns:
The number of keys that were deleted.
"""
try:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
if keys:
return await redis_client.delete(*keys)
return 0
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "
f"{user_id}: {e}"
)
return 0
# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
@@ -178,7 +144,6 @@ async def execute_node(
execution_processor: "ExecutionProcessor",
execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> BlockOutput:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -246,7 +211,6 @@ async def execute_node(
"user_id": user_id,
"execution_context": execution_context,
"execution_processor": execution_processor,
"nodes_to_skip": nodes_to_skip or set(),
}
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -544,7 +508,6 @@ class ExecutionProcessor:
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[NodesInputMasks],
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
nodes_to_skip: Optional[set[str]] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
logger=_logger,
@@ -567,7 +530,6 @@ class ExecutionProcessor:
db_client=db_client,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
)
if isinstance(status, BaseException):
raise status
@@ -613,7 +575,6 @@ class ExecutionProcessor:
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> ExecutionStatus:
status = ExecutionStatus.RUNNING
@@ -650,7 +611,6 @@ class ExecutionProcessor:
execution_processor=self,
execution_stats=stats,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
):
await persist_output(output_name, output_data)
@@ -962,21 +922,6 @@ class ExecutionProcessor:
queued_node_exec = execution_queue.get()
# Check if this node should be skipped due to optional credentials
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
log_metadata.info(
f"Skipping node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id} - optional credentials not configured"
)
# Mark the node as completed without executing
# No outputs will be produced, so downstream nodes won't trigger
update_node_execution_status(
db_client=db_client,
exec_id=queued_node_exec.node_exec_id,
status=ExecutionStatus.COMPLETED,
)
continue
log_metadata.debug(
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
@@ -1037,7 +982,6 @@ class ExecutionProcessor:
execution_stats,
execution_stats_lock,
),
nodes_to_skip=graph_exec.nodes_to_skip,
),
self.node_execution_loop,
)
@@ -1317,40 +1261,12 @@ class ExecutionProcessor:
graph_id: str,
e: InsufficientBalanceError,
):
# Check if we've already sent a notification for this user+agent combo.
# We only send one notification per user per agent until they top up credits.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
try:
redis_client = redis.get_redis()
# SET NX returns True only if the key was newly set (didn't exist)
is_new_notification = redis_client.set(
redis_key,
"1",
nx=True,
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
)
if not is_new_notification:
# Already notified for this user+agent, skip all notifications
logger.debug(
f"Skipping duplicate insufficient funds notification for "
f"user={user_id}, graph={graph_id}"
)
return
except Exception as redis_error:
# If Redis fails, log and continue to send the notification
# (better to occasionally duplicate than to never notify)
logger.warning(
f"Failed to check/set insufficient funds notification flag in Redis: "
f"{redis_error}"
)
shortfall = abs(e.amount) - e.balance
metadata = db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Queue user email notification
queue_notification(
NotificationEventModel(
user_id=user_id,
@@ -1364,7 +1280,6 @@ class ExecutionProcessor:
)
)
# Send Discord system alert
try:
user_email = db_client.get_user_email_by_id(user_id)

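The manager code above gates insufficient-funds alerts with a per-user, per-agent Redis flag: SET NX with a 30-day TTL lets only the first shortfall notify, and topping up clears the flags so alerts can fire again later. A minimal hedged sketch of that gate with redis-py (key prefix and TTL taken from the hunk; the client construction is illustrative, the platform uses its own Redis helper):

import redis  # assumes redis-py is available

INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60  # 30-day fallback cleanup

r = redis.Redis()  # illustrative client

def should_notify(user_id: str, graph_id: str) -> bool:
    # SET NX returns a truthy value only when the key was newly created,
    # i.e. this is the first alert for this user+agent combination.
    key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
    return bool(r.set(key, "1", nx=True, ex=NOTIFIED_TTL_SECONDS))

def clear_flags(user_id: str) -> int:
    # Called after a top-up so future shortfalls can notify again.
    keys = list(r.scan_iter(match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"))
    return r.delete(*keys) if keys else 0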
View File

@@ -1,560 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import NotificationType
from backend.data.notifications import ZeroBalanceData
from backend.executor.manager import (
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
ExecutionProcessor,
clear_insufficient_funds_notifications,
)
from backend.util.exceptions import InsufficientBalanceError
from backend.util.test import SpinTestServer
async def async_iter(items):
"""Helper to create an async iterator from a list."""
for item in items:
yield item
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
server: SpinTestServer,
):
"""Test that the first insufficient funds notification sends a Discord alert."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72, # $0.72
amount=-714, # Attempting to spend $7.14
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to simulate first-time notification (set returns True)
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.return_value = True # Key was newly set
# Create mock database client
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify notification was queued
mock_queue_notif.assert_called_once()
notification_call = mock_queue_notif.call_args[0][0]
assert notification_call.type == NotificationType.ZERO_BALANCE
assert notification_call.user_id == user_id
assert isinstance(notification_call.data, ZeroBalanceData)
assert notification_call.data.current_balance == 72
# Verify Redis was checked with correct key pattern
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
mock_redis_client.set.assert_called_once()
call_args = mock_redis_client.set.call_args
assert call_args[0][0] == expected_key
assert call_args[1]["nx"] is True
# Verify Discord alert was sent
mock_client.discord_system_alert.assert_called_once()
discord_message = mock_client.discord_system_alert.call_args[0][0]
assert "Insufficient Funds Alert" in discord_message
assert "test@example.com" in discord_message
assert "Test Agent" in discord_message
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_skips_duplicate_notifications(
server: SpinTestServer,
):
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to simulate duplicate notification (set returns False/None)
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.return_value = None # Key already existed
# Create mock database client
mock_db_client = MagicMock()
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify email notification was NOT queued (deduplication worked)
mock_queue_notif.assert_not_called()
# Verify Discord alert was NOT sent (deduplication worked)
mock_client.discord_system_alert.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
server: SpinTestServer,
):
"""Test that different agents for the same user get separate Discord alerts."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id_1 = "test-graph-111"
graph_id_2 = "test-graph-222"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch("backend.executor.manager.queue_notification"), patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
# Both calls return True (first time for each agent)
mock_redis_client.set.return_value = True
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# First agent notification
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_1,
e=error,
)
# Second agent notification
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_2,
e=error,
)
# Verify Discord alerts were sent for both agents
assert mock_client.discord_system_alert.call_count == 2
# Verify Redis was called with different keys
assert mock_redis_client.set.call_count == 2
calls = mock_redis_client.set.call_args_list
assert (
calls[0][0][0]
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
)
assert (
calls[1][0][0]
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
"""Test that clearing notifications removes all keys for a user."""
user_id = "test-user-123"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
# Mock scan_iter to return some keys as an async iterator
mock_keys = [
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
]
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
# delete is awaited, so use AsyncMock
mock_redis_client.delete = AsyncMock(return_value=3)
# Clear notifications
result = await clear_insufficient_funds_notifications(user_id)
# Verify correct pattern was used
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
# Verify delete was called with all keys
mock_redis_client.delete.assert_called_once_with(*mock_keys)
# Verify return value
assert result == 3
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
"""Test clearing notifications when there are no keys to clear."""
user_id = "test-user-no-notifications"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
# Mock scan_iter to return no keys as an async iterator
mock_redis_client.scan_iter.return_value = async_iter([])
# Clear notifications
result = await clear_insufficient_funds_notifications(user_id)
# Verify delete was not called
mock_redis_client.delete.assert_not_called()
# Verify return value
assert result == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications_handles_redis_error(
server: SpinTestServer,
):
"""Test that clearing notifications handles Redis errors gracefully."""
user_id = "test-user-redis-error"
with patch("backend.executor.manager.redis") as mock_redis_module:
# Mock get_redis_async to raise an error
mock_redis_module.get_redis_async = AsyncMock(
side_effect=Exception("Redis connection failed")
)
# Clear notifications should not raise, just return 0
result = await clear_insufficient_funds_notifications(user_id)
# Verify it returned 0 (graceful failure)
assert result == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_continues_on_redis_error(
server: SpinTestServer,
):
"""Test that both email and Discord notifications are still sent when Redis fails."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to raise an error
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.side_effect = Exception("Redis connection error")
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify email notification was still queued despite Redis error
mock_queue_notif.assert_called_once()
# Verify Discord alert was still sent despite Redis error
mock_client.discord_system_alert.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-grant-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter(
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
)
mock_redis_client.delete = AsyncMock(return_value=1)
# Create a concrete instance
credit_model = UserCredit()
# Call _add_transaction with GRANT type (should clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=500, # Positive amount
transaction_type=CreditTransactionType.GRANT,
is_active=True, # Active transaction
)
# Verify notification clearing was called
mock_redis_module.get_redis_async.assert_called_once()
mock_redis_client.scan_iter.assert_called_once_with(
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-topup-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter([])
mock_redis_client.delete = AsyncMock(return_value=0)
credit_model = UserCredit()
# Call _add_transaction with TOP_UP type (should clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=1000, # Positive amount
transaction_type=CreditTransactionType.TOP_UP,
is_active=True,
)
# Verify notification clearing was attempted
mock_redis_module.get_redis_async.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_skips_clearing_for_inactive_transaction(
server: SpinTestServer,
):
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-inactive"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
# Mock async Redis
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
credit_model = UserCredit()
# Call _add_transaction with is_active=False (should NOT clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
is_active=False, # Inactive - pending Stripe payment
)
# Verify notification clearing was NOT called
mock_redis_module.get_redis_async.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_skips_clearing_for_usage_transaction(
server: SpinTestServer,
):
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-usage"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
# Mock async Redis
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
credit_model = UserCredit()
# Call _add_transaction with USAGE type (spending, should NOT clear)
await credit_model._add_transaction(
user_id=user_id,
amount=-100, # Negative - spending credits
transaction_type=CreditTransactionType.USAGE,
is_active=True,
)
# Verify notification clearing was NOT called
mock_redis_module.get_redis_async.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-enable"
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
"backend.data.credit.query_raw_with_schema"
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
# Mock finding the pending transaction
mock_transaction = MagicMock()
mock_transaction.amount = 1000
mock_transaction.type = CreditTransactionType.TOP_UP
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
return_value=mock_transaction
)
# Mock the query to return updated balance
mock_query.return_value = [{"balance": 1500}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter(
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
)
mock_redis_client.delete = AsyncMock(return_value=1)
credit_model = UserCredit()
# Call _enable_transaction (simulates Stripe checkout completion)
from backend.util.json import SafeJson
await credit_model._enable_transaction(
transaction_key="cs_test_123",
user_id=user_id,
metadata=SafeJson({"payment": "completed"}),
)
# Verify notification clearing was called
mock_redis_module.get_redis_async.assert_called_once()
mock_redis_client.scan_iter.assert_called_once_with(
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
)
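For orientation, the tests above pin down when notification flags are cleared: on active GRANT and TOP_UP transactions (and when a pending TOP_UP is enabled after Stripe checkout), but not on inactive or USAGE transactions. A minimal sketch of the clearing step they exercise, assuming an async Redis client and reusing the INSUFFICIENT_FUNDS_NOTIFIED_PREFIX name from the assertions; this is illustrative only, not the repository's actual helper:

INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_notified"  # assumed value; the tests only rely on the name

async def clear_insufficient_funds_flags(redis_client, user_id: str) -> int:
    """Delete every {prefix}:{user_id}:{graph_id} flag so the user can be re-notified later."""
    removed = 0
    async for key in redis_client.scan_iter(
        match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
    ):
        removed += await redis_client.delete(key)
    return removed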

View File

@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[dict[str, dict[str, str]], set[str]]:
) -> dict[str, dict[str, str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors
and a set of nodes that should be skipped due to optional missing credentials.
Checks all credentials for all nodes of the graph and returns structured errors.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
"""
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
nodes_to_skip: set[str] = set()
for node in graph.nodes:
block = node.block
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
if not credentials_fields:
continue
# Track if any credential field is missing for this node
has_missing_credentials = False
for field_name, credentials_meta_type in credentials_fields.items():
try:
# Check nodes_input_masks first, then input_default
field_value = None
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
field_value = node_input_mask[field_name]
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]
# Check if credentials are missing (None, empty, or not present)
if field_value is None or (
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
credentials_meta = credentials_meta_type.model_validate(field_value)
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
# Missing credentials
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
except ValidationError as e:
# Validation error means credentials were provided but invalid
# This should always be an error, even if optional
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
continue
@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
)
except Exception as e:
# Handle any errors fetching credentials
# If credentials were explicitly configured but unavailable, it's an error
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue
# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id} will be skipped: optional credentials not configured"
)
return credential_errors, nodes_to_skip
return credential_errors
def make_node_credentials_input_map(
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
) -> Mapping[str, Mapping[str, str]]:
"""
Validate graph including credentials and return structured errors per node,
along with a set of nodes that should be skipped due to optional missing credentials.
Validate graph including credentials and return structured errors per node.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Validation errors per node
"""
# Get input validation errors
node_input_errors = GraphModel.validate_graph_get_errors(
graph, for_run=True, nodes_input_masks=nodes_input_masks
)
# Get credential input/availability/validation errors and nodes to skip
node_credential_input_errors, nodes_to_skip = (
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
# Get credential input/availability/validation errors
node_credential_input_errors = await _validate_node_input_credentials(
graph, user_id, nodes_input_masks
)
# Merge credential errors with structural errors
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
node_input_errors[node_id] = {}
node_input_errors[node_id].update(field_errors)
return node_input_errors, nodes_to_skip
return node_input_errors
async def _construct_starting_node_execution_input(
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
user_id: str,
graph_inputs: BlockInput,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
Returns:
tuple[
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
and the corresponding input data for that node.
set[str]: Node IDs that should be skipped (optional credentials not configured)
]
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
# Use new validation function that includes credentials
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
validation_errors = await validate_graph_with_credentials(
graph, user_id, nodes_input_masks
)
n_error_nodes = len(validation_errors)
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
return nodes_input, nodes_to_skip
return nodes_input
async def validate_and_construct_node_execution_input(
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
is_sub_graph: bool = False,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
This centralizes the logic used by both scheduler validation and actual execution.
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
GraphModel: Full graph object for the given `graph_id`.
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
dict[str, BlockInput]: Node input masks including all passed-in credentials.
set[str]: Node IDs that should be skipped (optional credentials not configured).
Raises:
NotFoundError: If the graph is not found.
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
nodes_input_masks or {},
)
starting_nodes_input, nodes_to_skip = (
await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
starting_nodes_input = await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
return graph, starting_nodes_input, nodes_input_masks
def _merge_nodes_input_masks(
@@ -826,9 +779,6 @@ async def add_graph_execution(
# Use existing execution's compiled input masks
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
# For resumed executions, nodes_to_skip was already determined at creation time
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
nodes_to_skip: set[str] = set()
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
else:
@@ -837,7 +787,7 @@ async def add_graph_execution(
)
# Create new execution
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
graph, starting_nodes_input, compiled_nodes_input_masks = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
@@ -886,7 +836,6 @@ async def add_graph_execution(
try:
graph_exec_entry = graph_exec.to_graph_execution_entry(
compiled_nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Publishing execution {graph_exec.id} to execution queue")

View File

@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
)
# Setup mock returns
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
nodes_to_skip: set[str] = set()
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip,
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
# Both executions should succeed (though they create different objects)
assert result1 == mock_graph_exec
assert result2 == mock_graph_exec_2
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
@pytest.mark.asyncio
async def test_validate_node_input_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns nodes_to_skip set
for nodes with credentials_optional=True and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=True
mock_node = mocker.MagicMock()
mock_node.id = "node-with-optional-creds"
mock_node.credentials_optional = True
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in nodes_to_skip, not in errors
assert mock_node.id in nodes_to_skip
assert mock_node.id not in errors
@pytest.mark.asyncio
async def test_validate_node_input_credentials_required_missing_creds_error(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns errors
for nodes with credentials_optional=False and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=False (required)
mock_node = mocker.MagicMock()
mock_node.id = "node-with-required-creds"
mock_node.credentials_optional = False
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in errors, not in nodes_to_skip
assert mock_node.id in errors
assert "credentials" in errors[mock_node.id]
assert "required" in errors[mock_node.id]["credentials"].lower()
assert mock_node.id not in nodes_to_skip
@pytest.mark.asyncio
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that validate_graph_with_credentials returns nodes_to_skip set
from _validate_node_input_credentials.
"""
from backend.executor.utils import validate_graph_with_credentials
# Mock _validate_node_input_credentials to return specific values
mock_validate = mocker.patch(
"backend.executor.utils._validate_node_input_credentials"
)
expected_errors = {"node1": {"field": "error"}}
expected_nodes_to_skip = {"node2", "node3"}
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
# Mock GraphModel with validate_graph_get_errors method
mock_graph = mocker.MagicMock()
mock_graph.validate_graph_get_errors.return_value = {}
# Call the function
errors, nodes_to_skip = await validate_graph_with_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Verify nodes_to_skip is passed through
assert nodes_to_skip == expected_nodes_to_skip
assert "node1" in errors
@pytest.mark.asyncio
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
"""
Test that add_graph_execution properly passes nodes_to_skip
to the graph execution entry.
"""
from backend.data.execution import GraphExecutionWithNodes
from backend.executor.utils import add_graph_execution
# Mock data
graph_id = "test-graph-id"
user_id = "test-user-id"
inputs = {"test_input": "test_value"}
graph_version = 1
# Mock the graph object
mock_graph = mocker.MagicMock()
mock_graph.version = graph_version
# Starting nodes and masks
starting_nodes_input = [("node1", {"input1": "value1"})]
compiled_nodes_input_masks = {}
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
# Mock the graph execution object
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []
# Track what's passed to to_graph_execution_entry
captured_kwargs = {}
def capture_to_entry(**kwargs):
captured_kwargs.update(kwargs)
return mocker.MagicMock()
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
# Setup mocks
mock_validate = mocker.patch(
"backend.executor.utils.validate_and_construct_node_execution_input"
)
mock_edb = mocker.patch("backend.executor.utils.execution_db")
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_udb = mocker.patch("backend.executor.utils.user_db")
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
# Setup returns - include nodes_to_skip in the tuple
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip, # This should be passed through
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
mock_user = mocker.MagicMock()
mock_user.timezone = "UTC"
mock_settings = mocker.MagicMock()
mock_settings.human_in_the_loop_safe_mode = True
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
mock_get_queue.return_value = mocker.AsyncMock()
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
# Call the function
await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
graph_version=graph_version,
)
# Verify nodes_to_skip was passed to to_graph_execution_entry
assert "nodes_to_skip" in captured_kwargs
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip

View File

@@ -8,7 +8,6 @@ from .discord import DiscordOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
from .reddit import RedditOAuthHandler
from .twitter import TwitterOAuthHandler
if TYPE_CHECKING:
@@ -21,7 +20,6 @@ _ORIGINAL_HANDLERS = [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
RedditOAuthHandler,
TwitterOAuthHandler,
TodoistOAuthHandler,
]

View File

@@ -1,208 +0,0 @@
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
from backend.util.settings import Settings
settings = Settings()
class RedditOAuthHandler(BaseOAuthHandler):
"""
Reddit OAuth 2.0 handler.
Based on the documentation at:
- https://github.com/reddit-archive/reddit/wiki/OAuth2
Notes:
- Reddit requires `duration=permanent` to get refresh tokens
- Access tokens expire after 1 hour (3600 seconds)
- Reddit requires HTTP Basic Auth for token requests
- Reddit requires a unique User-Agent header
"""
PROVIDER_NAME = ProviderName.REDDIT
DEFAULT_SCOPES: ClassVar[list[str]] = [
"identity", # Get username, verify auth
"read", # Access posts and comments
"submit", # Submit new posts and comments
"edit", # Edit own posts and comments
"history", # Access user's post history
"privatemessages", # Access inbox and send private messages
"flair", # Access and set flair on posts/subreddits
]
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
"""Generate Reddit OAuth 2.0 authorization URL"""
scopes = self.handle_default_scopes(scopes)
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(scopes),
"state": state,
"duration": "permanent", # Required for refresh tokens
}
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
"""Exchange authorization code for access tokens"""
scopes = self.handle_default_scopes(scopes)
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
}
# Reddit requires HTTP Basic Auth for token requests
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token exchange failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
username=username,
access_token=tokens["access_token"],
refresh_token=tokens.get("refresh_token"),
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
scopes=scopes,
)
async def _get_username(self, access_token: str) -> str:
"""Get the username from the access token"""
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": settings.config.reddit_user_agent,
}
response = await Requests().get(self.USERNAME_URL, headers=headers)
if not response.ok:
raise ValueError(f"Failed to get Reddit username: {response.status}")
data = response.json()
return data.get("name", "unknown")
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""Refresh access tokens using refresh token"""
if not credentials.refresh_token:
raise ValueError("No refresh token available")
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token refresh failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
# Reddit may or may not return a new refresh token
new_refresh_token = tokens.get("refresh_token")
if new_refresh_token:
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
elif credentials.refresh_token:
# Keep the existing refresh token
refresh_token = credentials.refresh_token
else:
refresh_token = None
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
username=username,
access_token=tokens["access_token"],
refresh_token=refresh_token,
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None,
scopes=credentials.scopes,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
"""Revoke the access token"""
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.REVOKE_URL, headers=headers, data=data, auth=auth
)
# Reddit returns 204 No Content on successful revocation
return response.ok
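Since the handler above is deleted wholesale, a condensed sketch of the token exchange it performed may help reviewers. This uses the third-party requests library rather than the project's Requests wrapper and is illustrative only; the endpoint, form fields, and Basic-auth requirement come from the deleted code:

import requests

def exchange_reddit_code(
    client_id: str, client_secret: str, redirect_uri: str, code: str, user_agent: str
) -> dict:
    """Exchange an authorization code for tokens at Reddit's token endpoint."""
    response = requests.post(
        "https://www.reddit.com/api/v1/access_token",
        headers={"User-Agent": user_agent},  # Reddit expects a unique User-Agent
        data={
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": redirect_uri,
        },
        auth=(client_id, client_secret),  # Reddit requires HTTP Basic auth here
        timeout=30,
    )
    response.raise_for_status()
    return response.json()  # access_token, refresh_token, expires_in, scope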

View File

@@ -264,7 +264,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
)
reddit_user_agent: str = Field(
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
default="AutoGPT:1.0 (by /u/autogpt)",
description="The user agent for the Reddit API",
)

View File

@@ -1,227 +0,0 @@
#!/usr/bin/env python3
"""
Generate a lightweight stub for prisma/types.py that collapses all exported
symbols to Any. This prevents Pyright from spending time/budget on Prisma's
query DSL types while keeping runtime behavior unchanged.
Usage:
poetry run gen-prisma-stub
This script automatically finds the prisma package location and generates
the types.pyi stub file in the same directory as types.py.
"""
from __future__ import annotations
import ast
import importlib.util
import sys
from pathlib import Path
from typing import Iterable, Set
def _iter_assigned_names(target: ast.expr) -> Iterable[str]:
"""Extract names from assignment targets (handles tuple unpacking)."""
if isinstance(target, ast.Name):
yield target.id
elif isinstance(target, (ast.Tuple, ast.List)):
for elt in target.elts:
yield from _iter_assigned_names(elt)
def _is_private(name: str) -> bool:
"""Check if a name is private (starts with _ but not __)."""
return name.startswith("_") and not name.startswith("__")
def _is_safe_type_alias(node: ast.Assign) -> bool:
"""Check if an assignment is a safe type alias that shouldn't be stubbed.
Safe types are:
- Literal types (don't cause type budget issues)
- Simple type references (SortMode, SortOrder, etc.)
- TypeVar definitions
"""
if not node.value:
return False
# Check if it's a Subscript (like Literal[...], Union[...], TypeVar[...])
if isinstance(node.value, ast.Subscript):
# Get the base type name
if isinstance(node.value.value, ast.Name):
base_name = node.value.value.id
# Literal types are safe
if base_name == "Literal":
return True
# TypeVar is safe
if base_name == "TypeVar":
return True
elif isinstance(node.value.value, ast.Attribute):
# Handle typing_extensions.Literal etc.
if node.value.value.attr == "Literal":
return True
# Check if it's a simple Name reference (like SortMode = _types.SortMode)
if isinstance(node.value, ast.Attribute):
return True
# Check if it's a Call (like TypeVar(...))
if isinstance(node.value, ast.Call):
if isinstance(node.value.func, ast.Name):
if node.value.func.id == "TypeVar":
return True
return False
def collect_top_level_symbols(
tree: ast.Module, source_lines: list[str]
) -> tuple[Set[str], Set[str], list[str], Set[str]]:
"""Collect all top-level symbols from an AST module.
Returns:
Tuple of (class_names, function_names, safe_variable_sources, unsafe_variable_names)
safe_variable_sources contains the actual source code lines for safe variables
"""
classes: Set[str] = set()
functions: Set[str] = set()
safe_variable_sources: list[str] = []
unsafe_variables: Set[str] = set()
for node in tree.body:
if isinstance(node, ast.ClassDef):
if not _is_private(node.name):
classes.add(node.name)
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if not _is_private(node.name):
functions.add(node.name)
elif isinstance(node, ast.Assign):
is_safe = _is_safe_type_alias(node)
names = []
for t in node.targets:
for n in _iter_assigned_names(t):
if not _is_private(n):
names.append(n)
if names:
if is_safe:
# Extract the source code for this assignment
start_line = node.lineno - 1 # 0-indexed
end_line = node.end_lineno if node.end_lineno else node.lineno
source = "\n".join(source_lines[start_line:end_line])
safe_variable_sources.append(source)
else:
unsafe_variables.update(names)
elif isinstance(node, ast.AnnAssign) and node.target:
# Annotated assignments are always stubbed
for n in _iter_assigned_names(node.target):
if not _is_private(n):
unsafe_variables.add(n)
return classes, functions, safe_variable_sources, unsafe_variables
def find_prisma_types_path() -> Path:
"""Find the prisma types.py file in the installed package."""
spec = importlib.util.find_spec("prisma")
if spec is None or spec.origin is None:
raise RuntimeError("Could not find prisma package. Is it installed?")
prisma_dir = Path(spec.origin).parent
types_path = prisma_dir / "types.py"
if not types_path.exists():
raise RuntimeError(f"prisma/types.py not found at {types_path}")
return types_path
def generate_stub(src_path: Path, stub_path: Path) -> int:
"""Generate the .pyi stub file from the source types.py."""
code = src_path.read_text(encoding="utf-8", errors="ignore")
source_lines = code.splitlines()
tree = ast.parse(code, filename=str(src_path))
classes, functions, safe_variable_sources, unsafe_variables = (
collect_top_level_symbols(tree, source_lines)
)
header = """\
# -*- coding: utf-8 -*-
# Auto-generated stub file - DO NOT EDIT
# Generated by gen_prisma_types_stub.py
#
# This stub intentionally collapses complex Prisma query DSL types to Any.
# Prisma's generated types can explode Pyright's type inference budgets
# on large schemas. We collapse them to Any so the rest of the codebase
# can remain strongly typed while keeping runtime behavior unchanged.
#
# Safe types (Literal, TypeVar, simple references) are preserved from the
# original types.py to maintain proper type checking where possible.
from __future__ import annotations
from typing import Any
from typing_extensions import Literal
# Re-export commonly used typing constructs that may be imported from this module
from typing import TYPE_CHECKING, TypeVar, Generic, Union, Optional, List, Dict
# Base type alias for stubbed Prisma types - allows any dict structure
_PrismaDict = dict[str, Any]
"""
lines = [header]
# Include safe variable definitions (Literal types, TypeVars, etc.)
lines.append("# Safe type definitions preserved from original types.py")
for source in safe_variable_sources:
lines.append(source)
lines.append("")
# Stub all classes and unsafe variables uniformly as dict[str, Any] aliases
# This allows:
# 1. Use in type annotations: x: SomeType
# 2. Constructor calls: SomeType(...)
# 3. Dict literal assignments: x: SomeType = {...}
lines.append(
"# Stubbed types (collapsed to dict[str, Any] to prevent type budget exhaustion)"
)
all_stubbed = sorted(classes | unsafe_variables)
for name in all_stubbed:
lines.append(f"{name} = _PrismaDict")
lines.append("")
# Stub functions
for name in sorted(functions):
lines.append(f"def {name}(*args: Any, **kwargs: Any) -> Any: ...")
lines.append("")
stub_path.write_text("\n".join(lines), encoding="utf-8")
return (
len(classes)
+ len(functions)
+ len(safe_variable_sources)
+ len(unsafe_variables)
)
def main() -> None:
"""Main entry point."""
try:
types_path = find_prisma_types_path()
stub_path = types_path.with_suffix(".pyi")
print(f"Found prisma types.py at: {types_path}")
print(f"Generating stub at: {stub_path}")
num_symbols = generate_stub(types_path, stub_path)
print(f"Generated {stub_path.name} with {num_symbols} Any-typed symbols")
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()
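For reviewers unfamiliar with the script being removed: its output was a types.pyi written next to prisma/types.py, preserving Literal/TypeVar aliases and collapsing everything else to a dict alias. A hypothetical excerpt of such a stub (symbol names below are illustrative; the real file mirrored whatever Prisma generated for this schema):

# Auto-generated stub file - DO NOT EDIT
from __future__ import annotations
from typing import Any
from typing_extensions import Literal

_PrismaDict = dict[str, Any]

# Safe type definitions preserved from original types.py
SortOrder = Literal["asc", "desc"]

# Stubbed types (collapsed to dict[str, Any] to prevent type budget exhaustion)
AgentGraphCreateInput = _PrismaDict
AgentGraphWhereInput = _PrismaDict

def some_helper(*args: Any, **kwargs: Any) -> Any: ...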

View File

@@ -25,9 +25,6 @@ def run(*command: str) -> None:
def lint():
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
run("gen-prisma-stub")
lint_step_args: list[list[str]] = [
["ruff", "check", *TARGET_DIRS, "--exit-zero"],
["ruff", "format", "--diff", "--check", LIBS_DIR],
@@ -52,6 +49,4 @@ def format():
run("ruff", "format", LIBS_DIR)
run("isort", "--profile", "black", BACKEND_DIR)
run("black", BACKEND_DIR)
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
run("gen-prisma-stub")
run("pyright", *TARGET_DIRS)

View File

@@ -1906,22 +1906,6 @@ httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
pydantic = ">=1.10,<3"
pyjwt = ">=2.10.1,<3.0.0"
[[package]]
name = "gravitas-md2gdocs"
version = "0.1.0"
description = "Convert Markdown to Google Docs API requests"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "gravitas_md2gdocs-0.1.0-py3-none-any.whl", hash = "sha256:0cb0627779fdd65c1604818af4142eea1b25d055060183363de1bae4d9e46508"},
{file = "gravitas_md2gdocs-0.1.0.tar.gz", hash = "sha256:bb3122fe9fa35c528f3f00b785d3f1398d350082d5d03f60f56c895bdcc68033"},
]
[package.extras]
dev = ["google-auth-oauthlib (>=1.0.0)", "pytest (>=7.0.0)", "pytest-cov (>=4.0.0)", "python-dotenv (>=1.0.0)", "ruff (>=0.1.0)"]
google = ["google-api-python-client (>=2.0.0)", "google-auth (>=2.0.0)"]
[[package]]
name = "gravitasml"
version = "0.1.4"
@@ -7295,4 +7279,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "a93ba0cea3b465cb6ec3e3f258b383b09f84ea352ccfdbfa112902cde5653fc6"
content-hash = "13b191b2a1989d3321ff713c66ff6f5f4f3b82d15df4d407e0e5dbf87d7522c4"

View File

@@ -82,7 +82,6 @@ firecrawl-py = "^4.3.6"
exa-py = "^1.14.20"
croniter = "^6.0.0"
stagehand = "^0.5.1"
gravitas-md2gdocs = "^0.1.0"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"
@@ -117,7 +116,6 @@ lint = "linter:lint"
test = "run_tests:test"
load-store-agents = "test.load_store_agents:run"
export-api-schema = "backend.cli.generate_openapi_json:main"
gen-prisma-stub = "gen_prisma_types_stub:main"
oauth-tool = "backend.cli.oauth_tool:cli"
[tool.isort]
@@ -135,9 +133,6 @@ ignore_patterns = []
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
# Disable syrupy plugin to avoid conflict with pytest-snapshot
# Both provide --snapshot-update argument causing ArgumentError
addopts = "-p no:syrupy"
filterwarnings = [
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",

View File

@@ -2,7 +2,6 @@
"created_at": "2025-09-04T13:37:00",
"credentials_input_schema": {
"properties": {},
"required": [],
"title": "TestGraphCredentialsInputSchema",
"type": "object"
},

View File

@@ -2,7 +2,6 @@
{
"credentials_input_schema": {
"properties": {},
"required": [],
"title": "TestGraphCredentialsInputSchema",
"type": "object"
},

View File

@@ -4,7 +4,6 @@
"id": "test-agent-1",
"graph_id": "test-agent-1",
"graph_version": 1,
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"image_url": null,
"creator_name": "Test Creator",
"creator_image_url": "",
@@ -42,7 +41,6 @@
"id": "test-agent-2",
"graph_id": "test-agent-2",
"graph_version": 1,
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"image_url": null,
"creator_name": "Test Creator",
"creator_image_url": "",

View File

@@ -1,7 +1,6 @@
{
"submissions": [
{
"listing_id": "test-listing-id",
"agent_id": "test-agent-id",
"agent_version": 1,
"name": "Test Agent",

View File

@@ -1,113 +0,0 @@
from unittest.mock import Mock
from backend.blocks.google.docs import GoogleDocsFormatTextBlock
def _make_mock_docs_service() -> Mock:
service = Mock()
# Ensure chained call exists: service.documents().batchUpdate(...).execute()
service.documents.return_value.batchUpdate.return_value.execute.return_value = {}
return service
def test_format_text_parses_shorthand_hex_color():
block = GoogleDocsFormatTextBlock()
service = _make_mock_docs_service()
result = block._format_text(
service,
document_id="doc_1",
start_index=1,
end_index=2,
bold=False,
italic=False,
underline=False,
font_size=0,
foreground_color="#FFF",
)
assert result["success"] is True
# Verify request body contains correct rgbColor for white.
_, kwargs = service.documents.return_value.batchUpdate.call_args
requests = kwargs["body"]["requests"]
rgb = requests[0]["updateTextStyle"]["textStyle"]["foregroundColor"]["color"][
"rgbColor"
]
assert rgb == {"red": 1.0, "green": 1.0, "blue": 1.0}
def test_format_text_parses_full_hex_color():
block = GoogleDocsFormatTextBlock()
service = _make_mock_docs_service()
result = block._format_text(
service,
document_id="doc_1",
start_index=1,
end_index=2,
bold=False,
italic=False,
underline=False,
font_size=0,
foreground_color="#FF0000",
)
assert result["success"] is True
_, kwargs = service.documents.return_value.batchUpdate.call_args
requests = kwargs["body"]["requests"]
rgb = requests[0]["updateTextStyle"]["textStyle"]["foregroundColor"]["color"][
"rgbColor"
]
assert rgb == {"red": 1.0, "green": 0.0, "blue": 0.0}
def test_format_text_ignores_invalid_color_when_other_fields_present():
block = GoogleDocsFormatTextBlock()
service = _make_mock_docs_service()
result = block._format_text(
service,
document_id="doc_1",
start_index=1,
end_index=2,
bold=True,
italic=False,
underline=False,
font_size=0,
foreground_color="#GGG",
)
assert result["success"] is True
assert "warning" in result
# Should still apply bold, but should NOT include foregroundColor in textStyle.
_, kwargs = service.documents.return_value.batchUpdate.call_args
requests = kwargs["body"]["requests"]
text_style = requests[0]["updateTextStyle"]["textStyle"]
fields = requests[0]["updateTextStyle"]["fields"]
assert text_style == {"bold": True}
assert fields == "bold"
def test_format_text_invalid_color_only_does_not_call_api():
block = GoogleDocsFormatTextBlock()
service = _make_mock_docs_service()
result = block._format_text(
service,
document_id="doc_1",
start_index=1,
end_index=2,
bold=False,
italic=False,
underline=False,
font_size=0,
foreground_color="#F",
)
assert result["success"] is False
assert "Invalid foreground_color" in result["message"]
service.documents.return_value.batchUpdate.assert_not_called()
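The deleted tests pin down the hex-colour handling they covered: shorthand #FFF expands to white, #FF0000 maps to pure red, and #GGG or #F is rejected. A standalone sketch of a parser satisfying exactly those expectations (illustrative; not the block's actual implementation):

def parse_hex_color(value: str) -> dict[str, float] | None:
    """Parse #RGB or #RRGGBB into Google-Docs-style rgbColor floats; None if invalid."""
    digits = value.lstrip("#")
    if len(digits) == 3:
        digits = "".join(c * 2 for c in digits)  # expand shorthand, e.g. FFF -> FFFFFF
    if len(digits) != 6:
        return None
    try:
        r, g, b = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
    except ValueError:
        return None
    return {"red": r / 255, "green": g / 255, "blue": b / 255}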

View File

@@ -37,18 +37,6 @@ class TestTranscribeYoutubeVideoBlock:
video_id = self.youtube_block.extract_video_id(url)
assert video_id == "dQw4w9WgXcQ"
def test_extract_video_id_shorts_url(self):
"""Test extracting video ID from YouTube Shorts URL."""
url = "https://www.youtube.com/shorts/dtUqwMu3e-g"
video_id = self.youtube_block.extract_video_id(url)
assert video_id == "dtUqwMu3e-g"
def test_extract_video_id_shorts_url_with_params(self):
"""Test extracting video ID from YouTube Shorts URL with query parameters."""
url = "https://www.youtube.com/shorts/dtUqwMu3e-g?feature=share"
video_id = self.youtube_block.extract_video_id(url)
assert video_id == "dtUqwMu3e-g"
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
def test_get_transcript_english_available(self, mock_api_class):
"""Test getting transcript when English is available."""

View File

@@ -1,146 +0,0 @@
/**
* Cloudflare Workers Script for docs.agpt.co → agpt.co/docs migration
*
* Deploy this script to handle all redirects with a single JavaScript file.
* No rule limits, easy to maintain, handles all edge cases.
*/
// URL mapping for special cases that don't follow patterns
const SPECIAL_MAPPINGS = {
// Root page
'/': '/docs/platform',
// Special cases that don't follow standard patterns
'/platform/d_id/': '/docs/integrations/block-integrations/d-id',
'/platform/blocks/blocks/': '/docs/integrations',
'/platform/blocks/decoder_block/': '/docs/integrations/block-integrations/text-decoder',
'/platform/blocks/http': '/docs/integrations/block-integrations/send-web-request',
'/platform/blocks/llm/': '/docs/integrations/block-integrations/ai-and-llm',
'/platform/blocks/time_blocks': '/docs/integrations/block-integrations/time-and-date',
'/platform/blocks/text_to_speech_block': '/docs/integrations/block-integrations/text-to-speech',
'/platform/blocks/ai_shortform_video_block': '/docs/integrations/block-integrations/ai-shortform-video',
'/platform/blocks/replicate_flux_advanced': '/docs/integrations/block-integrations/replicate-flux-advanced',
'/platform/blocks/flux_kontext': '/docs/integrations/block-integrations/flux-kontext',
'/platform/blocks/ai_condition/': '/docs/integrations/block-integrations/ai-condition',
'/platform/blocks/email_block': '/docs/integrations/block-integrations/email',
'/platform/blocks/google_maps': '/docs/integrations/block-integrations/google-maps',
'/platform/blocks/google/gmail': '/docs/integrations/block-integrations/gmail',
'/platform/blocks/github/issues/': '/docs/integrations/block-integrations/github-issues',
'/platform/blocks/github/repo/': '/docs/integrations/block-integrations/github-repo',
'/platform/blocks/github/pull_requests': '/docs/integrations/block-integrations/github-pull-requests',
'/platform/blocks/twitter/twitter': '/docs/integrations/block-integrations/twitter',
'/classic/setup/': '/docs/classic/setup/setting-up-autogpt-classic',
'/code-of-conduct/': '/docs/classic/help-us-improve-autogpt/code-of-conduct',
'/contributing/': '/docs/classic/contributing',
'/contribute/': '/docs/contribute',
'/forge/components/introduction/': '/docs/classic/forge/introduction'
};
/**
* Transform path by replacing underscores with hyphens and removing trailing slashes
*/
function transformPath(path) {
return path.replace(/_/g, '-').replace(/\/$/, '');
}
/**
* Handle docs.agpt.co redirects
*/
function handleDocsRedirect(url) {
const pathname = url.pathname;
// Check special mappings first
if (SPECIAL_MAPPINGS[pathname]) {
return `https://agpt.co${SPECIAL_MAPPINGS[pathname]}`;
}
// Pattern-based redirects
// Platform blocks: /platform/blocks/* → /docs/integrations/block-integrations/*
if (pathname.startsWith('/platform/blocks/')) {
const blockName = pathname.substring('/platform/blocks/'.length);
const transformedName = transformPath(blockName);
return `https://agpt.co/docs/integrations/block-integrations/${transformedName}`;
}
// Platform contributing: /platform/contributing/* → /docs/platform/contributing/*
if (pathname.startsWith('/platform/contributing/')) {
const subPath = pathname.substring('/platform/contributing/'.length);
return `https://agpt.co/docs/platform/contributing/${subPath}`;
}
// Platform general: /platform/* → /docs/platform/* (with underscore→hyphen)
if (pathname.startsWith('/platform/')) {
const subPath = pathname.substring('/platform/'.length);
const transformedPath = transformPath(subPath);
return `https://agpt.co/docs/platform/${transformedPath}`;
}
// Forge components: /forge/components/* → /docs/classic/forge/introduction/*
if (pathname.startsWith('/forge/components/')) {
const subPath = pathname.substring('/forge/components/'.length);
return `https://agpt.co/docs/classic/forge/introduction/${subPath}`;
}
// Forge general: /forge/* → /docs/classic/forge/*
if (pathname.startsWith('/forge/')) {
const subPath = pathname.substring('/forge/'.length);
return `https://agpt.co/docs/classic/forge/${subPath}`;
}
// Classic: /classic/* → /docs/classic/*
if (pathname.startsWith('/classic/')) {
const subPath = pathname.substring('/classic/'.length);
return `https://agpt.co/docs/classic/${subPath}`;
}
// Default fallback
return 'https://agpt.co/docs/';
}
/**
* Main Worker function
*/
export default {
async fetch(request, env, ctx) {
const url = new URL(request.url);
// Only handle docs.agpt.co requests
if (url.hostname === 'docs.agpt.co') {
const redirectUrl = handleDocsRedirect(url);
return new Response(null, {
status: 301,
headers: {
'Location': redirectUrl,
'Cache-Control': 'max-age=300' // Cache redirects for 5 minutes
}
});
}
// For non-docs requests, pass through or return 404
return new Response('Not Found', { status: 404 });
}
};
// Test function for local development
export function testRedirects() {
const testCases = [
'https://docs.agpt.co/',
'https://docs.agpt.co/platform/getting-started/',
'https://docs.agpt.co/platform/advanced_setup/',
'https://docs.agpt.co/platform/blocks/basic/',
'https://docs.agpt.co/platform/blocks/ai_condition/',
'https://docs.agpt.co/classic/setup/',
'https://docs.agpt.co/forge/components/agents/',
'https://docs.agpt.co/contributing/',
'https://docs.agpt.co/unknown-page'
];
console.log('Testing redirects:');
testCases.forEach(testUrl => {
const url = new URL(testUrl);
const result = handleDocsRedirect(url);
console.log(`${testUrl}${result}`);
});
}

View File

@@ -37,7 +37,7 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: migrate
command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"]
command: ["sh", "-c", "poetry run prisma generate && poetry run prisma migrate deploy"]
develop:
watch:
- path: ./

View File

@@ -46,15 +46,14 @@
"@radix-ui/react-scroll-area": "1.2.10",
"@radix-ui/react-select": "2.2.6",
"@radix-ui/react-separator": "1.1.7",
"@radix-ui/react-slider": "1.3.6",
"@radix-ui/react-slot": "1.2.3",
"@radix-ui/react-switch": "1.2.6",
"@radix-ui/react-tabs": "1.1.13",
"@radix-ui/react-toast": "1.2.15",
"@radix-ui/react-tooltip": "1.2.8",
"@rjsf/core": "6.1.2",
"@rjsf/utils": "6.1.2",
"@rjsf/validator-ajv8": "6.1.2",
"@rjsf/core": "5.24.13",
"@rjsf/utils": "5.24.13",
"@rjsf/validator-ajv8": "5.24.13",
"@sentry/nextjs": "10.27.0",
"@supabase/ssr": "0.7.0",
"@supabase/supabase-js": "2.78.0",
@@ -70,7 +69,6 @@
"cmdk": "1.1.1",
"cookie": "1.0.2",
"date-fns": "4.1.0",
"dexie": "4.2.1",
"dotenv": "17.2.3",
"elliptic": "6.6.1",
"embla-carousel-react": "8.6.0",
@@ -92,6 +90,7 @@
"react-currency-input-field": "4.0.3",
"react-day-picker": "9.11.1",
"react-dom": "18.3.1",
"react-drag-drop-files": "2.4.0",
"react-hook-form": "7.66.0",
"react-icons": "5.5.0",
"react-markdown": "9.0.3",

File diff suppressed because it is too large.

Binary file not shown (image removed; was 2.6 KiB).

Binary file not shown (image removed; was 16 KiB).

View File

@@ -1,4 +1,4 @@
import { OAuthPopupResultMessage } from "./types";
import { OAuthPopupResultMessage } from "@/components/renderers/input-renderer/fields/CredentialField/models/OAuthCredentialModal/useOAuthCredentialModal";
import { NextResponse } from "next/server";
// This route is intended to be used as the callback for integration OAuth flows,

View File

@@ -1,11 +0,0 @@
export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
| {
success: true;
code: string;
state: string;
}
| {
success: false;
message: string;
}
);

View File

@@ -16,7 +16,6 @@ import {
SheetTitle,
SheetTrigger,
} from "@/components/__legacy__/ui/sheet";
import { Button } from "@/components/atoms/Button/Button";
import {
Tooltip,
TooltipContent,
@@ -26,6 +25,7 @@ import {
import { BookOpenIcon } from "@phosphor-icons/react";
import { useMemo } from "react";
import { useShallow } from "zustand/react/shallow";
import { BuilderActionButton } from "../BuilderActionButton";
export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
const hasOutputs = useGraphStore(useShallow((state) => state.hasOutputs));
@@ -76,13 +76,9 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
<Tooltip>
<TooltipTrigger asChild>
<SheetTrigger asChild>
<Button
variant="outline"
size="icon"
disabled={!flowID || !hasOutputs()}
>
<BookOpenIcon className="size-4" />
</Button>
<BuilderActionButton disabled={!flowID || !hasOutputs()}>
<BookOpenIcon className="size-6" />
</BuilderActionButton>
</SheetTrigger>
</TooltipTrigger>
<TooltipContent>

View File

@@ -0,0 +1,37 @@
import { Button } from "@/components/atoms/Button/Button";
import { ButtonProps } from "@/components/atoms/Button/helpers";
import { cn } from "@/lib/utils";
import { CircleNotchIcon } from "@phosphor-icons/react";
export const BuilderActionButton = ({
children,
className,
isLoading,
...props
}: ButtonProps & { isLoading?: boolean }) => {
return (
<Button
variant="icon"
size={"small"}
className={cn(
"relative h-12 w-12 min-w-0 text-lg",
"bg-gradient-to-br from-zinc-50 to-zinc-200",
"border border-zinc-200",
"shadow-[inset_0_3px_0_0_rgba(255,255,255,0.5),0_2px_4px_0_rgba(0,0,0,0.2)]",
"dark:shadow-[inset_0_1px_0_0_rgba(255,255,255,0.1),0_2px_4px_0_rgba(0,0,0,0.4)]",
"hover:shadow-[inset_0_1px_0_0_rgba(255,255,255,0.5),0_1px_2px_0_rgba(0,0,0,0.2)]",
"active:shadow-[inset_0_2px_4px_0_rgba(0,0,0,0.2)]",
"transition-all duration-150",
"disabled:cursor-not-allowed disabled:opacity-50",
className,
)}
{...props}
>
{!isLoading ? (
children
) : (
<CircleNotchIcon className="size-6 animate-spin" />
)}
</Button>
);
};

View File

@@ -1,12 +1,12 @@
import { Button } from "@/components/atoms/Button/Button";
import { ShareIcon } from "@phosphor-icons/react";
import { BuilderActionButton } from "../BuilderActionButton";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { PublishAgentModal } from "@/components/contextual/PublishAgentModal/PublishAgentModal";
import { ShareIcon } from "@phosphor-icons/react";
import { usePublishToMarketplace } from "./usePublishToMarketplace";
import { PublishAgentModal } from "@/components/contextual/PublishAgentModal/PublishAgentModal";
export const PublishToMarketplace = ({ flowID }: { flowID: string | null }) => {
const { handlePublishToMarketplace, publishState, handleStateChange } =
@@ -16,14 +16,12 @@ export const PublishToMarketplace = ({ flowID }: { flowID: string | null }) => {
<>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="outline"
size="icon"
<BuilderActionButton
onClick={handlePublishToMarketplace}
disabled={!flowID}
>
<ShareIcon className="size-4" />
</Button>
<ShareIcon className="size-6 drop-shadow-sm" />
</BuilderActionButton>
</TooltipTrigger>
<TooltipContent>Publish to Marketplace</TooltipContent>
</Tooltip>
@@ -32,7 +30,6 @@ export const PublishToMarketplace = ({ flowID }: { flowID: string | null }) => {
targetState={publishState}
onStateChange={handleStateChange}
preSelectedAgentId={flowID || undefined}
showTrigger={false}
/>
</>
);

View File

@@ -1,14 +1,15 @@
import { useRunGraph } from "./useRunGraph";
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { Button } from "@/components/atoms/Button/Button";
import { useShallow } from "zustand/react/shallow";
import { PlayIcon, StopIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { RunInputDialog } from "../RunInputDialog/RunInputDialog";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { PlayIcon, StopIcon } from "@phosphor-icons/react";
import { useShallow } from "zustand/react/shallow";
import { RunInputDialog } from "../RunInputDialog/RunInputDialog";
import { useRunGraph } from "./useRunGraph";
import { BuilderActionButton } from "../BuilderActionButton";
export const RunGraph = ({ flowID }: { flowID: string | null }) => {
const {
@@ -28,19 +29,21 @@ export const RunGraph = ({ flowID }: { flowID: string | null }) => {
<>
<Tooltip>
<TooltipTrigger asChild>
<Button
size="icon"
variant={isGraphRunning ? "destructive" : "primary"}
<BuilderActionButton
className={cn(
isGraphRunning &&
"border-red-500 bg-gradient-to-br from-red-400 to-red-500 shadow-[inset_0_2px_0_0_rgba(255,255,255,0.5),0_2px_4px_0_rgba(0,0,0,0.2)]",
)}
onClick={isGraphRunning ? handleStopGraph : handleRunGraph}
disabled={!flowID || isExecutingGraph || isTerminatingGraph}
loading={isExecutingGraph || isTerminatingGraph || isSaving}
isLoading={isExecutingGraph || isTerminatingGraph || isSaving}
>
{!isGraphRunning ? (
<PlayIcon className="size-4" />
<PlayIcon className="size-6 drop-shadow-sm" />
) : (
<StopIcon className="size-4" />
<StopIcon className="size-6 drop-shadow-sm" />
)}
</Button>
</BuilderActionButton>
</TooltipTrigger>
<TooltipContent>
{isGraphRunning ? "Stop agent" : "Run agent"}

View File

@@ -5,7 +5,7 @@ import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { Button } from "@/components/atoms/Button/Button";
import { ClockIcon, PlayIcon } from "@phosphor-icons/react";
import { Text } from "@/components/atoms/Text/Text";
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
import { FormRenderer } from "@/components/renderers/input-renderer/FormRenderer";
import { useRunInputDialog } from "./useRunInputDialog";
import { CronSchedulerDialog } from "../CronSchedulerDialog/CronSchedulerDialog";
@@ -66,7 +66,6 @@ export const RunInputDialog = ({
formContext={{
showHandles: false,
size: "large",
showOptionalToggle: false,
}}
/>
</div>

View File

@@ -8,7 +8,7 @@ import {
import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
import { useMemo, useState } from "react";
import { uiSchema } from "../../../FlowEditor/nodes/uiSchema";
import { isCredentialFieldSchema } from "@/components/renderers/InputRenderer/custom/CredentialField/helpers";
import { isCredentialFieldSchema } from "@/components/renderers/input-renderer/fields/CredentialField/helpers";
export const useRunInputDialog = ({
setIsOpen,
@@ -66,7 +66,7 @@ export const useRunInputDialog = ({
if (isCredentialFieldSchema(fieldSchema)) {
dynamicUiSchema[fieldName] = {
...dynamicUiSchema[fieldName],
"ui:field": "custom/credential_field",
"ui:field": "credentials",
};
}
});
@@ -76,18 +76,12 @@ export const useRunInputDialog = ({
}, [credentialsSchema]);
const handleManualRun = async () => {
// Filter out incomplete credentials (those without a valid id)
// RJSF auto-populates const values (provider, type) but not id field
const validCredentials = Object.fromEntries(
Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id),
);
await executeGraph({
graphId: flowID ?? "",
graphVersion: flowVersion || null,
data: {
inputs: inputValues,
credentials_inputs: validCredentials,
credentials_inputs: credentialValues,
source: "builder",
},
});

View File

@@ -1,14 +1,14 @@
import { Button } from "@/components/atoms/Button/Button";
import { ClockIcon } from "@phosphor-icons/react";
import { RunInputDialog } from "../RunInputDialog/RunInputDialog";
import { useScheduleGraph } from "./useScheduleGraph";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { ClockIcon } from "@phosphor-icons/react";
import { CronSchedulerDialog } from "../CronSchedulerDialog/CronSchedulerDialog";
import { RunInputDialog } from "../RunInputDialog/RunInputDialog";
import { useScheduleGraph } from "./useScheduleGraph";
import { BuilderActionButton } from "../BuilderActionButton";
export const ScheduleGraph = ({ flowID }: { flowID: string | null }) => {
const {
@@ -23,14 +23,12 @@ export const ScheduleGraph = ({ flowID }: { flowID: string | null }) => {
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="outline"
size="icon"
<BuilderActionButton
onClick={handleScheduleGraph}
disabled={!flowID}
>
<ClockIcon className="size-4" />
</Button>
<ClockIcon className="size-6" />
</BuilderActionButton>
</TooltipTrigger>
<TooltipContent>
<p>Schedule Graph</p>

View File

@@ -1,7 +1,7 @@
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { useEffect, useMemo } from "react";
import { BuilderView } from "./components/BuilderViewTabs/BuilderViewTabs";
import { BuilderView } from "./BuilderViewTabs";
export function useBuilderView() {
const isNewFlowEditorEnabled = useGetFlag(Flag.NEW_FLOW_EDITOR);

View File

@@ -1,160 +0,0 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { ClockCounterClockwiseIcon, XIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { formatTimeAgo } from "@/lib/utils/time";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { useDraftRecoveryPopup } from "./useDraftRecoveryPopup";
import { Text } from "@/components/atoms/Text/Text";
import { AnimatePresence, motion } from "framer-motion";
import { DraftDiff } from "@/lib/dexie/draft-utils";
interface DraftRecoveryPopupProps {
isInitialLoadComplete: boolean;
}
function formatDiffSummary(diff: DraftDiff | null): string {
if (!diff) return "";
const parts: string[] = [];
// Node changes
const nodeChanges: string[] = [];
if (diff.nodes.added > 0) nodeChanges.push(`+${diff.nodes.added}`);
if (diff.nodes.removed > 0) nodeChanges.push(`-${diff.nodes.removed}`);
if (diff.nodes.modified > 0) nodeChanges.push(`~${diff.nodes.modified}`);
if (nodeChanges.length > 0) {
parts.push(
`${nodeChanges.join("/")} block${diff.nodes.added + diff.nodes.removed + diff.nodes.modified !== 1 ? "s" : ""}`,
);
}
// Edge changes
const edgeChanges: string[] = [];
if (diff.edges.added > 0) edgeChanges.push(`+${diff.edges.added}`);
if (diff.edges.removed > 0) edgeChanges.push(`-${diff.edges.removed}`);
if (diff.edges.modified > 0) edgeChanges.push(`~${diff.edges.modified}`);
if (edgeChanges.length > 0) {
parts.push(
`${edgeChanges.join("/")} connection${diff.edges.added + diff.edges.removed + diff.edges.modified !== 1 ? "s" : ""}`,
);
}
return parts.join(", ");
}
export function DraftRecoveryPopup({
isInitialLoadComplete,
}: DraftRecoveryPopupProps) {
const {
isOpen,
popupRef,
nodeCount,
edgeCount,
diff,
savedAt,
onLoad,
onDiscard,
} = useDraftRecoveryPopup(isInitialLoadComplete);
const diffSummary = formatDiffSummary(diff);
return (
<AnimatePresence>
{isOpen && (
<motion.div
ref={popupRef}
className={cn("absolute left-1/2 top-4 z-50")}
initial={{
opacity: 0,
x: "-50%",
y: "-150%",
scale: 0.5,
filter: "blur(20px)",
}}
animate={{
opacity: 1,
x: "-50%",
y: "0%",
scale: 1,
filter: "blur(0px)",
}}
exit={{
opacity: 0,
y: "-150%",
scale: 0.5,
filter: "blur(20px)",
transition: { duration: 0.4, type: "spring", bounce: 0.2 },
}}
transition={{ duration: 0.2, type: "spring", bounce: 0.2 }}
>
<div
className={cn(
"flex items-center gap-3 rounded-xlarge border border-amber-200 bg-amber-50 px-4 py-3 shadow-lg",
)}
>
<div className="flex items-center gap-2 text-amber-700 dark:text-amber-300">
<ClockCounterClockwiseIcon className="h-5 w-5" weight="fill" />
</div>
<div className="flex flex-col">
<Text
variant="small-medium"
className="text-amber-900 dark:text-amber-100"
>
Unsaved changes found
</Text>
<Text
variant="small"
className="text-amber-700 dark:text-amber-400"
>
{diffSummary ||
`${nodeCount} block${nodeCount !== 1 ? "s" : ""}, ${edgeCount} connection${edgeCount !== 1 ? "s" : ""}`}{" "}
{formatTimeAgo(new Date(savedAt).toISOString())}
</Text>
</div>
<div className="ml-2 flex items-center gap-2">
<Tooltip delayDuration={10}>
<TooltipTrigger asChild>
<Button
variant="primary"
size="small"
onClick={onLoad}
className="aspect-square min-w-0 p-1.5"
>
<ClockCounterClockwiseIcon size={20} weight="fill" />
<span className="sr-only">Restore changes</span>
</Button>
</TooltipTrigger>
<TooltipContent>Restore changes</TooltipContent>
</Tooltip>
<Tooltip delayDuration={10}>
<TooltipTrigger asChild>
<Button
variant="destructive"
size="icon"
onClick={onDiscard}
aria-label="Discard changes"
className="aspect-square min-w-0 p-1.5"
>
<XIcon size={20} />
<span className="sr-only">Discard changes</span>
</Button>
</TooltipTrigger>
<TooltipContent>Discard changes</TooltipContent>
</Tooltip>
</div>
</div>
</motion.div>
)}
</AnimatePresence>
);
}

View File

@@ -1,63 +0,0 @@
import { useEffect, useRef } from "react";
import { useDraftManager } from "../FlowEditor/Flow/useDraftManager";
export const useDraftRecoveryPopup = (isInitialLoadComplete: boolean) => {
const popupRef = useRef<HTMLDivElement>(null);
const {
isRecoveryOpen: isOpen,
savedAt,
nodeCount,
edgeCount,
diff,
loadDraft: onLoad,
discardDraft: onDiscard,
} = useDraftManager(isInitialLoadComplete);
useEffect(() => {
if (!isOpen) return;
const handleClickOutside = (event: MouseEvent) => {
if (
popupRef.current &&
!popupRef.current.contains(event.target as Node)
) {
onDiscard();
}
};
const timeoutId = setTimeout(() => {
document.addEventListener("mousedown", handleClickOutside);
}, 100);
return () => {
clearTimeout(timeoutId);
document.removeEventListener("mousedown", handleClickOutside);
};
}, [isOpen, onDiscard]);
useEffect(() => {
if (!isOpen) return;
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === "Escape") {
onDiscard();
}
};
document.addEventListener("keydown", handleKeyDown);
return () => {
document.removeEventListener("keydown", handleKeyDown);
};
}, [isOpen, onDiscard]);
return {
popupRef,
isOpen,
nodeCount,
edgeCount,
diff,
savedAt,
onLoad,
onDiscard,
};
};

View File

@@ -1,27 +1,26 @@
import { useGetV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { okData } from "@/app/api/helpers";
import { FloatingReviewsPanel } from "@/components/organisms/FloatingReviewsPanel/FloatingReviewsPanel";
import { Background, ReactFlow } from "@xyflow/react";
import { parseAsString, useQueryStates } from "nuqs";
import { useCallback, useMemo } from "react";
import { useShallow } from "zustand/react/shallow";
import { useGraphStore } from "../../../stores/graphStore";
import { useNodeStore } from "../../../stores/nodeStore";
import { BuilderActions } from "../../BuilderActions/BuilderActions";
import { DraftRecoveryPopup } from "../../DraftRecoveryDialog/DraftRecoveryPopup";
import { FloatingSafeModeToggle } from "../../FloatingSafeModeToogle";
import { ReactFlow, Background } from "@xyflow/react";
import NewControlPanel from "../../NewControlPanel/NewControlPanel";
import CustomEdge from "../edges/CustomEdge";
import { useCustomEdge } from "../edges/useCustomEdge";
import { useFlow } from "./useFlow";
import { useShallow } from "zustand/react/shallow";
import { useNodeStore } from "../../../stores/nodeStore";
import { useMemo, useEffect, useCallback } from "react";
import { CustomNode } from "../nodes/CustomNode/CustomNode";
import { CustomControls } from "./components/CustomControl";
import { useCustomEdge } from "../edges/useCustomEdge";
import { useFlowRealtime } from "./useFlowRealtime";
import { GraphLoadingBox } from "./components/GraphLoadingBox";
import { BuilderActions } from "../../BuilderActions/BuilderActions";
import { RunningBackground } from "./components/RunningBackground";
import { useGraphStore } from "../../../stores/graphStore";
import { useCopyPaste } from "./useCopyPaste";
import { FloatingReviewsPanel } from "@/components/organisms/FloatingReviewsPanel/FloatingReviewsPanel";
import { parseAsString, useQueryStates } from "nuqs";
import { CustomControls } from "./components/CustomControl";
import { useGetV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { okData } from "@/app/api/helpers";
import { TriggerAgentBanner } from "./components/TriggerAgentBanner";
import { resolveCollisions } from "./helpers/resolve-collision";
import { useCopyPaste } from "./useCopyPaste";
import { useFlow } from "./useFlow";
import { useFlowRealtime } from "./useFlowRealtime";
import { FloatingSafeModeToggle } from "../../FloatingSafeModeToogle";
export const Flow = () => {
const [{ flowID, flowExecutionID }] = useQueryStates({
@@ -42,18 +41,14 @@ export const Flow = () => {
const nodes = useNodeStore(useShallow((state) => state.nodes));
const setNodes = useNodeStore(useShallow((state) => state.setNodes));
const onNodesChange = useNodeStore(
useShallow((state) => state.onNodesChange),
);
const hasWebhookNodes = useNodeStore(
useShallow((state) => state.hasWebhookNodes()),
);
const nodeTypes = useMemo(() => ({ custom: CustomNode }), []);
const edgeTypes = useMemo(() => ({ custom: CustomEdge }), []);
const onNodeDragStop = useCallback(() => {
setNodes(
resolveCollisions(nodes, {
@@ -65,26 +60,29 @@ export const Flow = () => {
}, [setNodes, nodes]);
const { edges, onConnect, onEdgesChange } = useCustomEdge();
// for loading purpose
const {
onDragOver,
onDrop,
isFlowContentLoading,
isInitialLoadComplete,
isLocked,
setIsLocked,
} = useFlow();
// We use this hook to load the graph and convert them into custom nodes and edges.
const { onDragOver, onDrop, isFlowContentLoading, isLocked, setIsLocked } =
useFlow();
// This hook is used for websocket realtime updates.
useFlowRealtime();
// Copy/paste functionality
useCopyPaste();
const handleCopyPaste = useCopyPaste();
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
handleCopyPaste(event);
};
window.addEventListener("keydown", handleKeyDown);
return () => {
window.removeEventListener("keydown", handleKeyDown);
};
}, [handleCopyPaste]);
const isGraphRunning = useGraphStore(
useShallow((state) => state.isGraphRunning),
);
return (
<div className="flex h-full w-full dark:bg-slate-900">
<div className="relative flex-1">
@@ -97,9 +95,6 @@ export const Flow = () => {
onConnect={onConnect}
onEdgesChange={onEdgesChange}
onNodeDragStop={onNodeDragStop}
onNodeContextMenu={(event) => {
event.preventDefault();
}}
maxZoom={2}
minZoom={0.1}
onDragOver={onDragOver}
@@ -107,7 +102,6 @@ export const Flow = () => {
nodesDraggable={!isLocked}
nodesConnectable={!isLocked}
elementsSelectable={!isLocked}
deleteKeyCode={["Backspace", "Delete"]}
>
<Background />
<CustomControls setIsLocked={setIsLocked} isLocked={isLocked} />
@@ -121,7 +115,6 @@ export const Flow = () => {
className="right-2 top-32 p-2"
/>
)}
<DraftRecoveryPopup isInitialLoadComplete={isInitialLoadComplete} />
</ReactFlow>
</div>
{/* TODO: Need to update it in future - also do not send executionId as prop - rather use useQueryState inside the component */}

View File

@@ -48,6 +48,8 @@ export const resolveCollisions: CollisionAlgorithm = (
const width = (node.width ?? node.measured?.width ?? 0) + margin * 2;
const height = (node.height ?? node.measured?.height ?? 0) + margin * 2;
console.log("width", width);
console.log("height", height);
const x = node.position.x - margin;
const y = node.position.y - margin;

View File

@@ -1,4 +1,4 @@
import { useCallback, useEffect } from "react";
import { useCallback } from "react";
import { useReactFlow } from "@xyflow/react";
import { v4 as uuidv4 } from "uuid";
import { useNodeStore } from "../../../stores/nodeStore";
@@ -151,16 +151,5 @@ export function useCopyPaste() {
[getViewport, toast],
);
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
handleCopyPaste(event);
};
window.addEventListener("keydown", handleKeyDown);
return () => {
window.removeEventListener("keydown", handleKeyDown);
};
}, [handleCopyPaste]);
return handleCopyPaste;
}

View File

@@ -1,319 +0,0 @@
import { useState, useCallback, useEffect, useRef } from "react";
import { parseAsString, parseAsInteger, useQueryStates } from "nuqs";
import {
draftService,
getTempFlowId,
getOrCreateTempFlowId,
DraftData,
} from "@/services/builder-draft/draft-service";
import { BuilderDraft } from "@/lib/dexie/db";
import {
cleanNodes,
cleanEdges,
calculateDraftDiff,
DraftDiff,
} from "@/lib/dexie/draft-utils";
import { useNodeStore } from "../../../stores/nodeStore";
import { useEdgeStore } from "../../../stores/edgeStore";
import { useGraphStore } from "../../../stores/graphStore";
import { useHistoryStore } from "../../../stores/historyStore";
import isEqual from "lodash/isEqual";
const AUTO_SAVE_INTERVAL_MS = 15000; // 15 seconds
interface DraftRecoveryState {
isOpen: boolean;
draft: BuilderDraft | null;
diff: DraftDiff | null;
}
/**
* Consolidated hook for draft persistence and recovery
* - Auto-saves builder state every 15 seconds
* - Saves on beforeunload event
* - Checks for and manages unsaved drafts on load
*/
export function useDraftManager(isInitialLoadComplete: boolean) {
const [state, setState] = useState<DraftRecoveryState>({
isOpen: false,
draft: null,
diff: null,
});
const [{ flowID, flowVersion }] = useQueryStates({
flowID: parseAsString,
flowVersion: parseAsInteger,
});
const lastSavedStateRef = useRef<DraftData | null>(null);
const saveTimeoutRef = useRef<NodeJS.Timeout | null>(null);
const isDirtyRef = useRef(false);
const hasCheckedForDraft = useRef(false);
const getEffectiveFlowId = useCallback((): string => {
return flowID || getOrCreateTempFlowId();
}, [flowID]);
const getCurrentState = useCallback((): DraftData => {
const nodes = useNodeStore.getState().nodes;
const edges = useEdgeStore.getState().edges;
const nodeCounter = useNodeStore.getState().nodeCounter;
const graphStore = useGraphStore.getState();
return {
nodes,
edges,
graphSchemas: {
input: graphStore.inputSchema,
credentials: graphStore.credentialsInputSchema,
output: graphStore.outputSchema,
},
nodeCounter,
flowVersion: flowVersion ?? undefined,
};
}, [flowVersion]);
const cleanStateForComparison = useCallback((stateData: DraftData) => {
return {
nodes: cleanNodes(stateData.nodes),
edges: cleanEdges(stateData.edges),
};
}, []);
const hasChanges = useCallback((): boolean => {
const currentState = getCurrentState();
if (!lastSavedStateRef.current) {
return currentState.nodes.length > 0;
}
const currentClean = cleanStateForComparison(currentState);
const lastClean = cleanStateForComparison(lastSavedStateRef.current);
return !isEqual(currentClean, lastClean);
}, [getCurrentState, cleanStateForComparison]);
const saveDraft = useCallback(async () => {
const effectiveFlowId = getEffectiveFlowId();
const currentState = getCurrentState();
if (currentState.nodes.length === 0 && currentState.edges.length === 0) {
return;
}
if (!hasChanges()) {
return;
}
try {
await draftService.saveDraft(effectiveFlowId, currentState);
lastSavedStateRef.current = currentState;
isDirtyRef.current = false;
} catch (error) {
console.error("[DraftPersistence] Failed to save draft:", error);
}
}, [getEffectiveFlowId, getCurrentState, hasChanges]);
const scheduleSave = useCallback(() => {
isDirtyRef.current = true;
if (saveTimeoutRef.current) {
clearTimeout(saveTimeoutRef.current);
}
saveTimeoutRef.current = setTimeout(() => {
saveDraft();
}, AUTO_SAVE_INTERVAL_MS);
}, [saveDraft]);
useEffect(() => {
const unsubscribeNodes = useNodeStore.subscribe((storeState, prevState) => {
if (storeState.nodes !== prevState.nodes) {
scheduleSave();
}
});
const unsubscribeEdges = useEdgeStore.subscribe((storeState, prevState) => {
if (storeState.edges !== prevState.edges) {
scheduleSave();
}
});
return () => {
unsubscribeNodes();
unsubscribeEdges();
};
}, [scheduleSave]);
useEffect(() => {
const handleBeforeUnload = () => {
if (isDirtyRef.current) {
const effectiveFlowId = getEffectiveFlowId();
const currentState = getCurrentState();
if (
currentState.nodes.length === 0 &&
currentState.edges.length === 0
) {
return;
}
draftService.saveDraft(effectiveFlowId, currentState).catch(() => {
// Ignore errors on unload
});
}
};
window.addEventListener("beforeunload", handleBeforeUnload);
return () => {
window.removeEventListener("beforeunload", handleBeforeUnload);
};
}, [getEffectiveFlowId, getCurrentState]);
useEffect(() => {
return () => {
if (saveTimeoutRef.current) {
clearTimeout(saveTimeoutRef.current);
}
if (isDirtyRef.current) {
saveDraft();
}
};
}, [saveDraft]);
useEffect(() => {
draftService.cleanupExpired().catch((error) => {
console.error(
"[DraftPersistence] Failed to cleanup expired drafts:",
error,
);
});
}, []);
const checkForDraft = useCallback(async () => {
const effectiveFlowId = flowID || getTempFlowId();
if (!effectiveFlowId) {
return;
}
try {
const draft = await draftService.loadDraft(effectiveFlowId);
if (!draft) {
return;
}
const currentNodes = useNodeStore.getState().nodes;
const currentEdges = useEdgeStore.getState().edges;
const isDifferent = draftService.isDraftDifferent(
draft,
currentNodes,
currentEdges,
);
if (isDifferent && (draft.nodes.length > 0 || draft.edges.length > 0)) {
const diff = calculateDraftDiff(
draft.nodes,
draft.edges,
currentNodes,
currentEdges,
);
setState({
isOpen: true,
draft,
diff,
});
} else {
await draftService.deleteDraft(effectiveFlowId);
}
} catch (error) {
console.error("[DraftRecovery] Failed to check for draft:", error);
}
}, [flowID]);
useEffect(() => {
if (isInitialLoadComplete && !hasCheckedForDraft.current) {
hasCheckedForDraft.current = true;
checkForDraft();
}
}, [isInitialLoadComplete, checkForDraft]);
useEffect(() => {
hasCheckedForDraft.current = false;
setState({
isOpen: false,
draft: null,
diff: null,
});
}, [flowID]);
const loadDraft = useCallback(async () => {
if (!state.draft) return;
const { draft } = state;
try {
useNodeStore.getState().setNodes(draft.nodes);
useEdgeStore.getState().setEdges(draft.edges);
draft.nodes.forEach((node) => {
useNodeStore.getState().syncHardcodedValuesWithHandleIds(node.id);
});
if (draft.nodeCounter !== undefined) {
useNodeStore.setState({ nodeCounter: draft.nodeCounter });
}
if (draft.graphSchemas) {
useGraphStore
.getState()
.setGraphSchemas(
draft.graphSchemas.input as Record<string, unknown> | null,
draft.graphSchemas.credentials as Record<string, unknown> | null,
draft.graphSchemas.output as Record<string, unknown> | null,
);
}
setTimeout(() => {
useHistoryStore.getState().initializeHistory();
}, 100);
await draftService.deleteDraft(draft.id);
setState({
isOpen: false,
draft: null,
diff: null,
});
} catch (error) {
console.error("[DraftRecovery] Failed to load draft:", error);
}
}, [state.draft]);
const discardDraft = useCallback(async () => {
if (!state.draft) {
setState({ isOpen: false, draft: null, diff: null });
return;
}
try {
await draftService.deleteDraft(state.draft.id);
} catch (error) {
console.error("[DraftRecovery] Failed to discard draft:", error);
}
setState({ isOpen: false, draft: null, diff: null });
}, [state.draft]);
return {
// Recovery popup props
isRecoveryOpen: state.isOpen,
savedAt: state.draft?.savedAt ?? 0,
nodeCount: state.draft?.nodes.length ?? 0,
edgeCount: state.draft?.edges.length ?? 0,
diff: state.diff,
loadDraft,
discardDraft,
};
}

View File

@@ -21,7 +21,6 @@ import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecut
export const useFlow = () => {
const [isLocked, setIsLocked] = useState(false);
const [hasAutoFramed, setHasAutoFramed] = useState(false);
const [isInitialLoadComplete, setIsInitialLoadComplete] = useState(false);
const addNodes = useNodeStore(useShallow((state) => state.addNodes));
const addLinks = useEdgeStore(useShallow((state) => state.addLinks));
const updateNodeStatus = useNodeStore(
@@ -121,14 +120,6 @@ export const useFlow = () => {
if (customNodes.length > 0) {
useNodeStore.getState().setNodes([]);
addNodes(customNodes);
// Sync hardcoded values with handle IDs.
// If a keyvalue field has a key without a value, the backend omits it from hardcoded values.
// But if a handleId exists for that key, it causes inconsistency.
// This ensures hardcoded values stay in sync with handle IDs.
customNodes.forEach((node) => {
useNodeStore.getState().syncHardcodedValuesWithHandleIds(node.id);
});
}
}, [customNodes, addNodes]);
@@ -183,23 +174,11 @@ export const useFlow = () => {
if (customNodes.length > 0 && graph?.links) {
const timer = setTimeout(() => {
useHistoryStore.getState().initializeHistory();
// Mark initial load as complete after history is initialized
setIsInitialLoadComplete(true);
}, 100);
return () => clearTimeout(timer);
}
}, [customNodes, graph?.links]);
// Also mark as complete for new flows (no flowID) after a short delay
useEffect(() => {
if (!flowID && !isGraphLoading && !isBlocksLoading) {
const timer = setTimeout(() => {
setIsInitialLoadComplete(true);
}, 200);
return () => clearTimeout(timer);
}
}, [flowID, isGraphLoading, isBlocksLoading]);
useEffect(() => {
return () => {
useNodeStore.getState().setNodes([]);
@@ -238,7 +217,6 @@ export const useFlow = () => {
useEffect(() => {
setHasAutoFramed(false);
setIsInitialLoadComplete(false);
}, [flowID, flowVersion]);
// Drag and drop block from block menu
@@ -275,7 +253,6 @@ export const useFlow = () => {
return {
isFlowContentLoading: isGraphLoading || isBlocksLoading,
isInitialLoadComplete,
onDragOver,
onDrop,
isLocked,

View File

@@ -1,17 +1,12 @@
import {
Connection as RFConnection,
EdgeChange,
applyEdgeChanges,
} from "@xyflow/react";
import { Connection as RFConnection, EdgeChange } from "@xyflow/react";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { useCallback } from "react";
import { useNodeStore } from "../../../stores/nodeStore";
import { CustomEdge } from "./CustomEdge";
export const useCustomEdge = () => {
const edges = useEdgeStore((s) => s.edges);
const addEdge = useEdgeStore((s) => s.addEdge);
const setEdges = useEdgeStore((s) => s.setEdges);
const removeEdge = useEdgeStore((s) => s.removeEdge);
const onConnect = useCallback(
(conn: RFConnection) => {
@@ -50,10 +45,14 @@ export const useCustomEdge = () => {
);
const onEdgesChange = useCallback(
(changes: EdgeChange<CustomEdge>[]) => {
setEdges(applyEdgeChanges(changes, edges));
(changes: EdgeChange[]) => {
changes.forEach((change) => {
if (change.type === "remove") {
removeEdge(change.id);
}
});
},
[edges, setEdges],
[removeEdge],
);
return { edges, onConnect, onEdgesChange };

View File

@@ -1,32 +1,26 @@
import { CircleIcon } from "@phosphor-icons/react";
import { Handle, Position } from "@xyflow/react";
import { useEdgeStore } from "../../../stores/edgeStore";
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
import { cn } from "@/lib/utils";
const InputNodeHandle = ({
const NodeHandle = ({
handleId,
nodeId,
isConnected,
side,
}: {
handleId: string;
nodeId: string;
isConnected: boolean;
side: "left" | "right";
}) => {
const cleanedHandleId = cleanUpHandleId(handleId);
const isInputConnected = useEdgeStore((state) =>
state.isInputConnected(nodeId ?? "", cleanedHandleId),
);
return (
<Handle
type={"target"}
position={Position.Left}
id={cleanedHandleId}
className={"-ml-6 mr-2"}
type={side === "left" ? "target" : "source"}
position={side === "left" ? Position.Left : Position.Right}
id={handleId}
className={side === "left" ? "-ml-4 mr-2" : "-mr-2 ml-2"}
>
<div className="pointer-events-none">
<CircleIcon
size={16}
weight={isInputConnected ? "fill" : "duotone"}
weight={isConnected ? "fill" : "duotone"}
className={"text-gray-400 opacity-100"}
/>
</div>
@@ -34,35 +28,4 @@ const InputNodeHandle = ({
);
};
const OutputNodeHandle = ({
field_name,
nodeId,
hexColor,
}: {
field_name: string;
nodeId: string;
hexColor: string;
}) => {
const isOutputConnected = useEdgeStore((state) =>
state.isOutputConnected(nodeId, field_name),
);
return (
<Handle
type={"source"}
position={Position.Right}
id={field_name}
className={"-mr-2 ml-2"}
>
<div className="pointer-events-none">
<CircleIcon
size={16}
weight={"duotone"}
color={isOutputConnected ? hexColor : "gray"}
className={cn("text-gray-400 opacity-100")}
/>
</div>
</Handle>
);
};
export { InputNodeHandle, OutputNodeHandle };
export default NodeHandle;

View File

@@ -1,4 +1,31 @@
// Here we are handling single level of nesting, if need more in future then i will update it
/**
* Handle ID Types for different input structures
*
* Examples:
* SIMPLE: "message"
* NESTED: "config.api_key"
* ARRAY: "items_$_0", "items_$_1"
* KEY_VALUE: "headers_#_Authorization", "params_#_limit"
*
* Note: All handle IDs are sanitized to remove spaces and special characters.
* Spaces become underscores, and special characters are removed.
* Example: "user name" becomes "user_name", "email@domain.com" becomes "emaildomaincom"
*/
export enum HandleIdType {
SIMPLE = "SIMPLE",
NESTED = "NESTED",
ARRAY = "ARRAY",
KEY_VALUE = "KEY_VALUE",
}
const fromRjsfId = (id: string): string => {
if (!id) return "";
const parts = id.split("_");
const filtered = parts.filter(
(p) => p !== "root" && p !== "properties" && p.length > 0,
);
return filtered.join("_") || "";
};
const sanitizeForHandleId = (str: string): string => {
if (!str) return "";
@@ -11,53 +38,51 @@ const sanitizeForHandleId = (str: string): string => {
.replace(/^_|_$/g, ""); // Remove leading/trailing underscores
};
const cleanTitleId = (id: string): string => {
if (!id) return "";
if (id.endsWith("_title")) {
id = id.slice(0, -6);
}
const parts = id.split("_");
const filtered = parts.filter(
(p) => p !== "root" && p !== "properties" && p.length > 0,
);
const filtered_id = filtered.join("_") || "";
return filtered_id;
};
export const generateHandleIdFromTitleId = (
export const generateHandleId = (
fieldKey: string,
{
isObjectProperty,
isAdditionalProperty,
isArrayItem,
}: {
isArrayItem?: boolean;
isObjectProperty?: boolean;
isAdditionalProperty?: boolean;
} = {
isArrayItem: false,
isObjectProperty: false,
isAdditionalProperty: false,
},
nestedValues: string[] = [],
type: HandleIdType = HandleIdType.SIMPLE,
): string => {
if (!fieldKey) return "";
const filteredKey = cleanTitleId(fieldKey);
if (isAdditionalProperty || isArrayItem) {
return filteredKey;
}
const cleanedKey = sanitizeForHandleId(filteredKey);
fieldKey = fromRjsfId(fieldKey);
fieldKey = sanitizeForHandleId(fieldKey);
if (isObjectProperty) {
// "config_api_key" -> "config.api_key"
const parts = cleanedKey.split("_");
if (parts.length >= 2) {
const baseName = parts[0];
const propertyName = parts.slice(1).join("_");
return `${baseName}.${propertyName}`;
}
if (type === HandleIdType.SIMPLE || nestedValues.length === 0) {
return fieldKey;
}
return cleanedKey;
const sanitizedNestedValues = nestedValues.map((value) =>
sanitizeForHandleId(value),
);
switch (type) {
case HandleIdType.NESTED:
return [fieldKey, ...sanitizedNestedValues].join(".");
case HandleIdType.ARRAY:
return [fieldKey, ...sanitizedNestedValues].join("_$_");
case HandleIdType.KEY_VALUE:
return [fieldKey, ...sanitizedNestedValues].join("_#_");
default:
return fieldKey;
}
};
export const parseKeyValueHandleId = (
handleId: string,
type: HandleIdType,
): string => {
if (type === HandleIdType.KEY_VALUE) {
return handleId.split("_#_")[1];
} else if (type === HandleIdType.ARRAY) {
return handleId.split("_$_")[1];
} else if (type === HandleIdType.NESTED) {
return handleId.split(".")[1];
} else if (type === HandleIdType.SIMPLE) {
return handleId.split("_")[1];
}
return "";
};
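For reference, a minimal usage sketch of the handle ID helpers introduced in this diff (illustrative only — the import path is assumed, and the space-to-underscore sanitization follows the doc comment above):

import {
  HandleIdType,
  generateHandleId,
  parseKeyValueHandleId,
} from "./helpers";

// SIMPLE: RJSF prefixes ("root", "properties") are stripped, then the key is sanitized
console.log(generateHandleId("root_properties_message")); // "message"

// NESTED: base key and nested values joined with "."
console.log(generateHandleId("config", ["api key"], HandleIdType.NESTED)); // "config.api_key"

// ARRAY: joined with "_$_"
console.log(generateHandleId("items", ["0"], HandleIdType.ARRAY)); // "items_$_0"

// KEY_VALUE: joined with "_#_"
console.log(
  generateHandleId("headers", ["Authorization"], HandleIdType.KEY_VALUE),
); // "headers_#_Authorization"

// Extract the nested part back out of a composite handle ID
console.log(
  parseKeyValueHandleId("headers_#_Authorization", HandleIdType.KEY_VALUE),
); // "Authorization"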

View File

@@ -1,25 +1,24 @@
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { BlockCost } from "@/app/api/__generated__/models/blockCost";
import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem";
import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
import { NodeModelMetadata } from "@/app/api/__generated__/models/nodeModelMetadata";
import { preprocessInputSchema } from "@/components/renderers/InputRenderer/utils/input-schema-pre-processor";
import { cn } from "@/lib/utils";
import { RJSFSchema } from "@rjsf/utils";
import { NodeProps, Node as XYNode } from "@xyflow/react";
import React from "react";
import { Node as XYNode, NodeProps } from "@xyflow/react";
import { RJSFSchema } from "@rjsf/utils";
import { BlockUIType } from "../../../types";
import { FormCreator } from "../FormCreator";
import { OutputHandler } from "../OutputHandler";
import { AyrshareConnectButton } from "./components/AyrshareConnectButton";
import { NodeAdvancedToggle } from "./components/NodeAdvancedToggle";
import { NodeContainer } from "./components/NodeContainer";
import { NodeExecutionBadge } from "./components/NodeExecutionBadge";
import { NodeHeader } from "./components/NodeHeader";
import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
import { StickyNoteBlock } from "./components/StickyNoteBlock";
import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem";
import { BlockCost } from "@/app/api/__generated__/models/blockCost";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
import { NodeContainer } from "./components/NodeContainer";
import { NodeHeader } from "./components/NodeHeader";
import { FormCreator } from "../FormCreator";
import { preprocessInputSchema } from "@/components/renderers/input-renderer/utils/input-schema-pre-processor";
import { OutputHandler } from "../OutputHandler";
import { NodeAdvancedToggle } from "./components/NodeAdvancedToggle";
import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
import { NodeExecutionBadge } from "./components/NodeExecutionBadge";
import { cn } from "@/lib/utils";
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
import { AyrshareConnectButton } from "./components/AyrshareConnectButton";
import { NodeModelMetadata } from "@/app/api/__generated__/models/nodeModelMetadata";
export type CustomNodeData = {
hardcodedValues: {
@@ -89,7 +88,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
// Currently all blockTypes design are similar - that's why i am using the same component for all of them
// If in future - if we need some drastic change in some blockTypes design - we can create separate components for them
const node = (
return (
<NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
<div className="rounded-xlarge bg-white">
<NodeHeader data={data} nodeId={nodeId} />
@@ -100,7 +99,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
nodeId={nodeId}
uiType={data.uiType}
className={cn(
"bg-white px-4",
"bg-white pr-6",
isWebhook && "pointer-events-none opacity-50",
)}
showHandles={showHandles}
@@ -118,15 +117,6 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
<NodeExecutionBadge nodeId={nodeId} />
</NodeContainer>
);
return (
<NodeRightClickMenu
nodeId={nodeId}
subGraphID={data.hardcodedValues?.graph_id}
>
{node}
</NodeRightClickMenu>
);
},
);

View File

@@ -8,7 +8,7 @@ export const NodeAdvancedToggle = ({ nodeId }: { nodeId: string }) => {
);
const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
return (
<div className="flex items-center justify-between gap-2 rounded-b-xlarge border-t border-zinc-200 bg-white px-5 py-3.5">
<div className="flex items-center justify-between gap-2 rounded-b-xlarge border-t border-slate-200/50 bg-white px-5 py-3.5">
<Text variant="body" className="font-medium text-slate-700">
Advanced
</Text>

View File

@@ -22,7 +22,7 @@ export const NodeContainer = ({
return (
<div
className={cn(
"z-12 w-[350px] rounded-xlarge ring-1 ring-slate-200/60",
"z-12 max-w-[370px] rounded-xlarge ring-1 ring-slate-200/60",
selected && "shadow-lg ring-2 ring-slate-200",
status && nodeStyleBasedOnStatus[status],
hasErrors ? nodeStyleBasedOnStatus[AgentExecutionStatus.FAILED] : "",

View File

@@ -1,31 +1,26 @@
import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { Separator } from "@/components/__legacy__/ui/separator";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import {
SecondaryDropdownMenuContent,
SecondaryDropdownMenuItem,
SecondaryDropdownMenuSeparator,
} from "@/components/molecules/SecondaryMenu/SecondaryMenu";
import {
ArrowSquareOutIcon,
CopyIcon,
DotsThreeOutlineVerticalIcon,
TrashIcon,
} from "@phosphor-icons/react";
import { DotsThreeOutlineVerticalIcon } from "@phosphor-icons/react";
import { Copy, Trash2, ExternalLink } from "lucide-react";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore";
import { useReactFlow } from "@xyflow/react";
type Props = {
export const NodeContextMenu = ({
nodeId,
subGraphID,
}: {
nodeId: string;
subGraphID?: string;
};
export const NodeContextMenu = ({ nodeId, subGraphID }: Props) => {
}) => {
const { deleteElements } = useReactFlow();
function handleCopy() {
const handleCopy = () => {
useNodeStore.setState((state) => ({
nodes: state.nodes.map((node) => ({
...node,
@@ -35,47 +30,47 @@ export const NodeContextMenu = ({ nodeId, subGraphID }: Props) => {
useCopyPasteStore.getState().copySelectedNodes();
useCopyPasteStore.getState().pasteNodes();
}
};
function handleDelete() {
const handleDelete = () => {
deleteElements({ nodes: [{ id: nodeId }] });
}
};
return (
<DropdownMenu>
<DropdownMenuTrigger className="py-2">
<DotsThreeOutlineVerticalIcon size={16} weight="fill" />
</DropdownMenuTrigger>
<SecondaryDropdownMenuContent side="right" align="start">
<SecondaryDropdownMenuItem onClick={handleCopy}>
<CopyIcon size={20} className="mr-2 dark:text-gray-100" />
<span className="dark:text-gray-100">Copy</span>
</SecondaryDropdownMenuItem>
<SecondaryDropdownMenuSeparator />
<DropdownMenuContent
side="right"
align="start"
className="rounded-xlarge"
>
<DropdownMenuItem onClick={handleCopy} className="hover:rounded-xlarge">
<Copy className="mr-2 h-4 w-4" />
Copy Node
</DropdownMenuItem>
{subGraphID && (
<>
<SecondaryDropdownMenuItem
onClick={() => window.open(`/build?flowID=${subGraphID}`)}
>
<ArrowSquareOutIcon
size={20}
className="mr-2 dark:text-gray-100"
/>
<span className="dark:text-gray-100">Open agent</span>
</SecondaryDropdownMenuItem>
<SecondaryDropdownMenuSeparator />
</>
<DropdownMenuItem
onClick={() => window.open(`/build?flowID=${subGraphID}`)}
className="hover:rounded-xlarge"
>
<ExternalLink className="mr-2 h-4 w-4" />
Open Agent
</DropdownMenuItem>
)}
<SecondaryDropdownMenuItem variant="destructive" onClick={handleDelete}>
<TrashIcon
size={20}
className="mr-2 text-red-500 dark:text-red-400"
/>
<span className="dark:text-red-400">Delete</span>
</SecondaryDropdownMenuItem>
</SecondaryDropdownMenuContent>
<Separator className="my-2" />
<DropdownMenuItem
onClick={handleDelete}
className="text-red-600 hover:rounded-xlarge"
>
<Trash2 className="mr-2 h-4 w-4" />
Delete
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
);
};

View File

@@ -1,30 +1,29 @@
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { Text } from "@/components/atoms/Text/Text";
import { beautifyString, cn } from "@/lib/utils";
import { NodeCost } from "./NodeCost";
import { NodeBadges } from "./NodeBadges";
import { NodeContextMenu } from "./NodeContextMenu";
import { CustomNodeData } from "../CustomNode";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { useState } from "react";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { beautifyString, cn } from "@/lib/utils";
import { useState } from "react";
import { CustomNodeData } from "../CustomNode";
import { NodeBadges } from "./NodeBadges";
import { NodeContextMenu } from "./NodeContextMenu";
import { NodeCost } from "./NodeCost";
type Props = {
export const NodeHeader = ({
data,
nodeId,
}: {
data: CustomNodeData;
nodeId: string;
};
export const NodeHeader = ({ data, nodeId }: Props) => {
}) => {
const updateNodeData = useNodeStore((state) => state.updateNodeData);
const title = (data.metadata?.customized_name as string) || data.title;
const [isEditingTitle, setIsEditingTitle] = useState(false);
const [editedTitle, setEditedTitle] = useState(
beautifyString(title).replace("Block", "").trim(),
);
const [editedTitle, setEditedTitle] = useState(title);
const handleTitleEdit = () => {
updateNodeData(nodeId, {
@@ -42,7 +41,7 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
};
return (
<div className="flex h-auto flex-col gap-1 rounded-xlarge border-b border-zinc-200 bg-gradient-to-r from-slate-50/80 to-white/90 px-4 py-4 pt-3">
<div className="flex h-auto flex-col gap-1 rounded-xlarge border-b border-slate-200/50 bg-gradient-to-r from-slate-50/80 to-white/90 px-4 py-4 pt-3">
{/* Title row with context menu */}
<div className="flex items-start justify-between gap-2">
<div className="flex min-w-0 flex-1 items-center gap-2">
@@ -68,16 +67,13 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
<Tooltip>
<TooltipTrigger asChild>
<div>
<Text
variant="large-semibold"
className="line-clamp-1 hover:cursor-text"
>
{beautifyString(title).replace("Block", "").trim()}
<Text variant="large-semibold" className="line-clamp-1">
{beautifyString(title)}
</Text>
</div>
</TooltipTrigger>
<TooltipContent>
<p>{beautifyString(title).replace("Block", "").trim()}</p>
<p>{beautifyString(title)}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>

View File

@@ -23,7 +23,7 @@ export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
}
return (
<div className="flex flex-col gap-3 rounded-b-xl border-t border-zinc-200 px-4 py-4">
<div className="flex flex-col gap-3 rounded-b-xl border-t border-slate-200/50 px-4 py-4">
<div className="flex items-center justify-between">
<Text variant="body-medium" className="!font-semibold text-slate-700">
Node Output

View File

@@ -151,7 +151,7 @@ export const NodeDataViewer: FC<NodeDataViewerProps> = ({
</div>
<div className="flex justify-end pt-4">
{outputItems.length > 1 && (
{outputItems.length > 0 && (
<OutputActions
items={outputItems.map((item) => ({
value: item.value,

View File

@@ -1,104 +0,0 @@
import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import {
SecondaryMenuContent,
SecondaryMenuItem,
SecondaryMenuSeparator,
} from "@/components/molecules/SecondaryMenu/SecondaryMenu";
import { ArrowSquareOutIcon, CopyIcon, TrashIcon } from "@phosphor-icons/react";
import * as ContextMenu from "@radix-ui/react-context-menu";
import { useReactFlow } from "@xyflow/react";
import { useEffect, useRef } from "react";
import { CustomNode } from "../CustomNode";
type Props = {
nodeId: string;
subGraphID?: string;
children: React.ReactNode;
};
const DOUBLE_CLICK_TIMEOUT = 300;
export function NodeRightClickMenu({ nodeId, subGraphID, children }: Props) {
const { deleteElements } = useReactFlow<CustomNode>();
const lastRightClickTime = useRef<number>(0);
const containerRef = useRef<HTMLDivElement>(null);
function copyNode() {
useNodeStore.setState((state) => ({
nodes: state.nodes.map((node) => ({
...node,
selected: node.id === nodeId,
})),
}));
useCopyPasteStore.getState().copySelectedNodes();
useCopyPasteStore.getState().pasteNodes();
}
function deleteNode() {
deleteElements({ nodes: [{ id: nodeId }] });
}
useEffect(() => {
const container = containerRef.current;
if (!container) return;
function handleContextMenu(e: MouseEvent) {
const now = Date.now();
const timeSinceLastClick = now - lastRightClickTime.current;
if (timeSinceLastClick < DOUBLE_CLICK_TIMEOUT) {
e.stopImmediatePropagation();
lastRightClickTime.current = 0;
return;
}
lastRightClickTime.current = now;
}
container.addEventListener("contextmenu", handleContextMenu, true);
return () => {
container.removeEventListener("contextmenu", handleContextMenu, true);
};
}, []);
return (
<ContextMenu.Root>
<ContextMenu.Trigger asChild>
<div ref={containerRef}>{children}</div>
</ContextMenu.Trigger>
<SecondaryMenuContent>
<SecondaryMenuItem onSelect={copyNode}>
<CopyIcon size={20} className="mr-2 dark:text-gray-100" />
<span className="dark:text-gray-100">Copy</span>
</SecondaryMenuItem>
<SecondaryMenuSeparator />
{subGraphID && (
<>
<SecondaryMenuItem
onClick={() => window.open(`/build?flowID=${subGraphID}`)}
>
<ArrowSquareOutIcon
size={20}
className="mr-2 dark:text-gray-100"
/>
<span className="dark:text-gray-100">Open agent</span>
</SecondaryMenuItem>
<SecondaryMenuSeparator />
</>
)}
<SecondaryMenuItem variant="destructive" onSelect={deleteNode}>
<TrashIcon
size={20}
className="mr-2 text-red-500 dark:text-red-400"
/>
<span className="dark:text-red-400">Delete</span>
</SecondaryMenuItem>
</SecondaryMenuContent>
</ContextMenu.Root>
);
}

View File

@@ -1,6 +1,6 @@
import { useMemo } from "react";
import { FormCreator } from "../../FormCreator";
import { preprocessInputSchema } from "@/components/renderers/InputRenderer/utils/input-schema-pre-processor";
import { preprocessInputSchema } from "@/components/renderers/input-renderer/utils/input-schema-pre-processor";
import { CustomNodeData } from "../CustomNode";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";

View File

@@ -3,7 +3,7 @@ import React from "react";
import { uiSchema } from "./uiSchema";
import { useNodeStore } from "../../../stores/nodeStore";
import { BlockUIType } from "../../types";
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
import { FormRenderer } from "@/components/renderers/input-renderer/FormRenderer";
export const FormCreator = React.memo(
({

View File

@@ -4,7 +4,7 @@ import { CaretDownIcon, InfoIcon } from "@phosphor-icons/react";
import { RJSFSchema } from "@rjsf/utils";
import { useState } from "react";
import { OutputNodeHandle } from "../handlers/NodeHandle";
import NodeHandle from "../handlers/NodeHandle";
import {
Tooltip,
TooltipContent,
@@ -13,6 +13,7 @@ import {
} from "@/components/atoms/Tooltip/BaseTooltip";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { getTypeDisplayInfo } from "./helpers";
import { generateHandleId } from "../handlers/helpers";
import { BlockUIType } from "../../types";
export const OutputHandler = ({
@@ -28,73 +29,8 @@ export const OutputHandler = ({
const properties = outputSchema?.properties || {};
const [isOutputVisible, setIsOutputVisible] = useState(true);
const showHandles = uiType !== BlockUIType.OUTPUT;
const renderOutputHandles = (
schema: RJSFSchema,
keyPrefix: string = "",
titlePrefix: string = "",
): React.ReactNode[] => {
return Object.entries(schema).map(
([key, fieldSchema]: [string, RJSFSchema]) => {
const fullKey = keyPrefix ? `${keyPrefix}_#_${key}` : key;
const fieldTitle = titlePrefix + (fieldSchema?.title || key);
const isConnected = isOutputConnected(nodeId, fullKey);
const shouldShow = isConnected || isOutputVisible;
const { displayType, colorClass, hexColor } =
getTypeDisplayInfo(fieldSchema);
return shouldShow ? (
<div key={fullKey} className="flex flex-col items-end gap-2">
<div className="relative flex items-center gap-2">
{fieldSchema?.description && (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<span
style={{ marginLeft: 6, cursor: "pointer" }}
aria-label="info"
tabIndex={0}
>
<InfoIcon />
</span>
</TooltipTrigger>
<TooltipContent>{fieldSchema?.description}</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
<Text variant="body" className="text-slate-700">
{fieldTitle}
</Text>
<Text variant="small" as="span" className={colorClass}>
({displayType})
</Text>
{showHandles && (
<OutputNodeHandle
field_name={fullKey}
nodeId={nodeId}
hexColor={hexColor}
/>
)}
</div>
{/* Recursively render nested properties */}
{fieldSchema?.properties &&
renderOutputHandles(
fieldSchema.properties,
fullKey,
`${fieldTitle}.`,
)}
</div>
) : null;
},
);
};
return (
<div className="flex flex-col items-end justify-between gap-2 rounded-b-xlarge border-t border-zinc-200 bg-white py-3.5">
<div className="flex flex-col items-end justify-between gap-2 rounded-b-xlarge border-t border-slate-200/50 bg-white py-3.5">
<Button
variant="ghost"
className="mr-4 h-fit min-w-0 p-0 hover:border-transparent hover:bg-transparent"
@@ -113,9 +49,50 @@ export const OutputHandler = ({
</Text>
</Button>
<div className="flex flex-col items-end gap-2">
{renderOutputHandles(properties)}
</div>
{
<div className="flex flex-col items-end gap-2">
{Object.entries(properties).map(([key, property]: [string, any]) => {
const isConnected = isOutputConnected(nodeId, key);
const shouldShow = isConnected || isOutputVisible;
const { displayType, colorClass } = getTypeDisplayInfo(property);
return shouldShow ? (
<div key={key} className="relative flex items-center gap-2">
{property?.description && (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<span
style={{ marginLeft: 6, cursor: "pointer" }}
aria-label="info"
tabIndex={0}
>
<InfoIcon />
</span>
</TooltipTrigger>
<TooltipContent>{property?.description}</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
<Text variant="body" className="text-slate-700">
{property?.title || key}{" "}
</Text>
<Text variant="small" as="span" className={colorClass}>
({displayType})
</Text>
<NodeHandle
handleId={
uiType === BlockUIType.AGENT ? key : generateHandleId(key)
}
isConnected={isConnected}
side="right"
/>
</div>
) : null;
})}
</div>
}
</div>
);
};

View File

@@ -89,53 +89,17 @@ export function extractOptions(
// get display type and color for schema types [need for type display next to field name]
export const getTypeDisplayInfo = (schema: any) => {
if (
schema?.type === "array" &&
"format" in schema &&
schema.format === "table"
) {
return {
displayType: "table",
colorClass: "!text-indigo-500",
hexColor: "#6366f1",
};
}
if (schema?.type === "string" && schema?.format) {
const formatMap: Record<
string,
{ displayType: string; colorClass: string; hexColor: string }
{ displayType: string; colorClass: string }
> = {
file: {
displayType: "file",
colorClass: "!text-green-500",
hexColor: "#22c55e",
},
date: {
displayType: "date",
colorClass: "!text-blue-500",
hexColor: "#3b82f6",
},
time: {
displayType: "time",
colorClass: "!text-blue-500",
hexColor: "#3b82f6",
},
"date-time": {
displayType: "datetime",
colorClass: "!text-blue-500",
hexColor: "#3b82f6",
},
"long-text": {
displayType: "text",
colorClass: "!text-green-500",
hexColor: "#22c55e",
},
"short-text": {
displayType: "text",
colorClass: "!text-green-500",
hexColor: "#22c55e",
},
file: { displayType: "file", colorClass: "!text-green-500" },
date: { displayType: "date", colorClass: "!text-blue-500" },
time: { displayType: "time", colorClass: "!text-blue-500" },
"date-time": { displayType: "datetime", colorClass: "!text-blue-500" },
"long-text": { displayType: "text", colorClass: "!text-green-500" },
"short-text": { displayType: "text", colorClass: "!text-green-500" },
};
const formatInfo = formatMap[schema.format];
@@ -167,23 +131,10 @@ export const getTypeDisplayInfo = (schema: any) => {
any: "!text-gray-500",
};
const hexColorMap: Record<string, string> = {
string: "#22c55e",
number: "#3b82f6",
integer: "#3b82f6",
boolean: "#eab308",
object: "#a855f7",
array: "#6366f1",
null: "#6b7280",
any: "#6b7280",
};
const colorClass = colorMap[schema?.type] || "!text-gray-500";
const hexColor = hexColorMap[schema?.type] || "#6b7280";
return {
displayType,
colorClass,
hexColor,
};
};

View File

@@ -1,6 +1,6 @@
export const uiSchema = {
credentials: {
"ui:field": "custom/credential_field",
"ui:field": "credentials",
provider: { "ui:widget": "hidden" },
type: { "ui:widget": "hidden" },
id: { "ui:autofocus": true },

View File

@@ -24,7 +24,7 @@ export const ControlPanelButton: React.FC<Props> = ({
role={as === "div" ? "button" : undefined}
disabled={as === "button" ? disabled : undefined}
className={cn(
"flex w-auto items-center justify-center whitespace-normal bg-white px-4 py-4 text-zinc-800 shadow-none hover:cursor-pointer hover:bg-zinc-100 hover:text-zinc-950 focus:ring-0",
"flex h-[4.25rem] w-[4.25rem] items-center justify-center whitespace-normal bg-white p-[1.38rem] text-zinc-800 shadow-none hover:cursor-pointer hover:bg-zinc-100 hover:text-zinc-950 focus:ring-0",
selected &&
"bg-violet-50 text-violet-700 hover:cursor-default hover:bg-violet-50 hover:text-violet-700 active:bg-violet-50 active:text-violet-700",
disabled && "cursor-not-allowed opacity-50 hover:cursor-not-allowed",

Some files were not shown because too many files have changed in this diff.