mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-13 00:58:16 -05:00
Compare commits
1 Commits
dev
...
fix/librar
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfdd23d8f8 |
@@ -1,37 +0,0 @@
|
||||
{
|
||||
"worktreeCopyPatterns": [
|
||||
".env*",
|
||||
".vscode/**",
|
||||
".auth/**",
|
||||
".claude/**",
|
||||
"autogpt_platform/.env*",
|
||||
"autogpt_platform/backend/.env*",
|
||||
"autogpt_platform/frontend/.env*",
|
||||
"autogpt_platform/frontend/.auth/**",
|
||||
"autogpt_platform/db/docker/.env*"
|
||||
],
|
||||
"worktreeCopyIgnores": [
|
||||
"**/node_modules/**",
|
||||
"**/dist/**",
|
||||
"**/.git/**",
|
||||
"**/Thumbs.db",
|
||||
"**/.DS_Store",
|
||||
"**/.next/**",
|
||||
"**/__pycache__/**",
|
||||
"**/.ruff_cache/**",
|
||||
"**/.pytest_cache/**",
|
||||
"**/*.pyc",
|
||||
"**/playwright-report/**",
|
||||
"**/logs/**",
|
||||
"**/site/**"
|
||||
],
|
||||
"worktreePathTemplate": "$BASE_PATH.worktree",
|
||||
"postCreateCmd": [
|
||||
"cd autogpt_platform/autogpt_libs && poetry install",
|
||||
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
||||
"cd autogpt_platform/frontend && pnpm install",
|
||||
"cd docs && pip install -r requirements.txt"
|
||||
],
|
||||
"terminalCommand": "code .",
|
||||
"deleteBranchWithWorktree": false
|
||||
}
|
||||
@@ -16,7 +16,6 @@
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
|
||||
2
.github/workflows/claude-dependabot.yml
vendored
2
.github/workflows/claude-dependabot.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
12
.github/workflows/copilot-setup-steps.yml
vendored
12
.github/workflows/copilot-setup-steps.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
@@ -108,16 +108,6 @@ jobs:
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
# Remove large unused tools to free disk space for Docker builds
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -134,7 +134,7 @@ jobs:
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
|
||||
@@ -12,7 +12,6 @@ reset-db:
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
# View logs for core services
|
||||
logs-core:
|
||||
@@ -34,7 +33,6 @@ init-env:
|
||||
migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
|
||||
@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
RUN poetry run prisma generate
|
||||
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
|
||||
@@ -119,12 +119,58 @@ async def list_library_agents(
|
||||
f"Retrieved {len(library_agents)} library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
# Batch fetch StoreListings for all agents' graphs
|
||||
graph_ids = {agent.agentGraphId for agent in library_agents if agent.AgentGraph}
|
||||
store_listings_map: dict[str, prisma.models.StoreListing] = {}
|
||||
profiles_map: dict[str, prisma.models.Profile] = {}
|
||||
|
||||
if graph_ids:
|
||||
store_listings = await prisma.models.StoreListing.prisma().find_many(
|
||||
where={
|
||||
"agentGraphId": {"in": list(graph_ids)},
|
||||
"isDeleted": False,
|
||||
"hasApprovedVersion": True,
|
||||
},
|
||||
include={"ActiveVersion": True},
|
||||
)
|
||||
|
||||
# Build map of graph_id -> StoreListing
|
||||
for listing in store_listings:
|
||||
if listing.agentGraphId:
|
||||
store_listings_map[listing.agentGraphId] = listing
|
||||
|
||||
# Fetch profiles for store listing owners
|
||||
owning_user_ids = {
|
||||
listing.owningUserId
|
||||
for listing in store_listings_map.values()
|
||||
if listing.owningUserId
|
||||
}
|
||||
if owning_user_ids:
|
||||
profiles = await prisma.models.Profile.prisma().find_many(
|
||||
where={"userId": {"in": list(owning_user_ids)}}
|
||||
)
|
||||
profiles_map = {profile.userId: profile for profile in profiles}
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
# Get store listing and profile for this agent's graph
|
||||
store_listing = None
|
||||
profile = None
|
||||
if agent.AgentGraph and agent.agentGraphId in store_listings_map:
|
||||
store_listing = store_listings_map[agent.agentGraphId]
|
||||
if (
|
||||
store_listing
|
||||
and store_listing.owningUserId
|
||||
and store_listing.owningUserId in profiles_map
|
||||
):
|
||||
profile = profiles_map[store_listing.owningUserId]
|
||||
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent, store_listing=store_listing, profile=profile
|
||||
)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -205,12 +251,58 @@ async def list_favorite_library_agents(
|
||||
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
# Batch fetch StoreListings for all agents' graphs
|
||||
graph_ids = {agent.agentGraphId for agent in library_agents if agent.AgentGraph}
|
||||
store_listings_map: dict[str, prisma.models.StoreListing] = {}
|
||||
profiles_map: dict[str, prisma.models.Profile] = {}
|
||||
|
||||
if graph_ids:
|
||||
store_listings = await prisma.models.StoreListing.prisma().find_many(
|
||||
where={
|
||||
"agentGraphId": {"in": list(graph_ids)},
|
||||
"isDeleted": False,
|
||||
"hasApprovedVersion": True,
|
||||
},
|
||||
include={"ActiveVersion": True},
|
||||
)
|
||||
|
||||
# Build map of graph_id -> StoreListing
|
||||
for listing in store_listings:
|
||||
if listing.agentGraphId:
|
||||
store_listings_map[listing.agentGraphId] = listing
|
||||
|
||||
# Fetch profiles for store listing owners
|
||||
owning_user_ids = {
|
||||
listing.owningUserId
|
||||
for listing in store_listings_map.values()
|
||||
if listing.owningUserId
|
||||
}
|
||||
if owning_user_ids:
|
||||
profiles = await prisma.models.Profile.prisma().find_many(
|
||||
where={"userId": {"in": list(owning_user_ids)}}
|
||||
)
|
||||
profiles_map = {profile.userId: profile for profile in profiles}
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
# Get store listing and profile for this agent's graph
|
||||
store_listing = None
|
||||
profile = None
|
||||
if agent.AgentGraph and agent.agentGraphId in store_listings_map:
|
||||
store_listing = store_listings_map[agent.agentGraphId]
|
||||
if (
|
||||
store_listing
|
||||
and store_listing.owningUserId
|
||||
and store_listing.owningUserId in profiles_map
|
||||
):
|
||||
profile = profiles_map[store_listing.owningUserId]
|
||||
|
||||
library_agent = library_model.LibraryAgent.from_db(
|
||||
agent, store_listing=store_listing, profile=profile
|
||||
)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
@@ -489,7 +581,7 @@ async def update_agent_version_in_library(
|
||||
agent_graph_version: int,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Updates the agent version in the library for any agent owned by the user.
|
||||
Updates the agent version in the library if useGraphIsActiveVersion is True.
|
||||
|
||||
Args:
|
||||
user_id: Owner of the LibraryAgent.
|
||||
@@ -498,31 +590,20 @@ async def update_agent_version_in_library(
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's an error with the update.
|
||||
NotFoundError: If no library agent is found for this user and agent.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating agent version in library for user #{user_id}, "
|
||||
f"agent #{agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
async with transaction() as tx:
|
||||
library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise(
|
||||
try:
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": agent_graph_id,
|
||||
"useGraphIsActiveVersion": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Delete any conflicting LibraryAgent for the target version
|
||||
await prisma.models.LibraryAgent.prisma(tx).delete_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": agent_graph_id,
|
||||
"agentGraphVersion": agent_graph_version,
|
||||
"id": {"not": library_agent.id},
|
||||
}
|
||||
)
|
||||
|
||||
lib = await prisma.models.LibraryAgent.prisma(tx).update(
|
||||
lib = await prisma.models.LibraryAgent.prisma().update(
|
||||
where={"id": library_agent.id},
|
||||
data={
|
||||
"AgentGraph": {
|
||||
@@ -536,13 +617,13 @@ async def update_agent_version_in_library(
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
if lib is None:
|
||||
raise NotFoundError(f"Library agent {library_agent.id} not found")
|
||||
|
||||
if lib is None:
|
||||
raise NotFoundError(
|
||||
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating agent version in library: {e}")
|
||||
raise DatabaseError("Failed to update agent version in library") from e
|
||||
|
||||
|
||||
async def update_library_agent(
|
||||
@@ -836,7 +917,6 @@ async def add_store_agent_to_library(
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(
|
||||
_initialize_graph_settings(graph_model).model_dump()
|
||||
),
|
||||
|
||||
@@ -48,7 +48,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str # ID of user who owns/created this agent graph
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -164,7 +163,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
|
||||
@@ -42,7 +42,6 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -65,7 +64,6 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -140,7 +138,6 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -208,7 +205,6 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
|
||||
@@ -614,7 +614,6 @@ async def get_store_submissions(
|
||||
submission_models = []
|
||||
for sub in submissions:
|
||||
submission_model = store_model.StoreSubmission(
|
||||
listing_id=sub.listing_id,
|
||||
agent_id=sub.agent_id,
|
||||
agent_version=sub.agent_version,
|
||||
name=sub.name,
|
||||
@@ -668,48 +667,35 @@ async def delete_store_submission(
|
||||
submission_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a store submission version as the submitting user.
|
||||
Delete a store listing submission as the submitting user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user
|
||||
submission_id: StoreListingVersion ID to delete
|
||||
submission_id: ID of the submission to be deleted
|
||||
|
||||
Returns:
|
||||
bool: True if successfully deleted
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
logger.debug(f"Deleting store submission {submission_id} for user {user_id}")
|
||||
|
||||
try:
|
||||
# Find the submission version with ownership check
|
||||
version = await prisma.models.StoreListingVersion.prisma().find_first(
|
||||
where={"id": submission_id}, include={"StoreListing": True}
|
||||
# Verify the submission belongs to this user
|
||||
submission = await prisma.models.StoreListing.prisma().find_first(
|
||||
where={"agentGraphId": submission_id, "owningUserId": user_id}
|
||||
)
|
||||
|
||||
if (
|
||||
not version
|
||||
or not version.StoreListing
|
||||
or version.StoreListing.owningUserId != user_id
|
||||
):
|
||||
raise store_exceptions.SubmissionNotFoundError("Submission not found")
|
||||
|
||||
# Prevent deletion of approved submissions
|
||||
if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
"Cannot delete approved submissions"
|
||||
if not submission:
|
||||
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
|
||||
raise store_exceptions.SubmissionNotFoundError(
|
||||
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
|
||||
)
|
||||
|
||||
# Delete the version
|
||||
await prisma.models.StoreListingVersion.prisma().delete(
|
||||
where={"id": version.id}
|
||||
)
|
||||
# Delete the submission
|
||||
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
|
||||
|
||||
# Clean up empty listing if this was the last version
|
||||
remaining = await prisma.models.StoreListingVersion.prisma().count(
|
||||
where={"storeListingId": version.storeListingId}
|
||||
logger.debug(
|
||||
f"Successfully deleted submission {submission_id} for user {user_id}"
|
||||
)
|
||||
if remaining == 0:
|
||||
await prisma.models.StoreListing.prisma().delete(
|
||||
where={"id": version.storeListingId}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -773,15 +759,9 @@ async def create_store_submission(
|
||||
logger.warning(
|
||||
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
|
||||
)
|
||||
# Provide more user-friendly error message when agent_id is empty
|
||||
if not agent_id or agent_id.strip() == "":
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
"No agent selected. Please select an agent before submitting to the store."
|
||||
)
|
||||
else:
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
# Check if listing already exists for this agent
|
||||
existing_listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
@@ -853,7 +833,6 @@ async def create_store_submission(
|
||||
logger.debug(f"Created store listing for agent {agent_id}")
|
||||
# Return submission details
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=listing.id,
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
name=name,
|
||||
@@ -965,56 +944,81 @@ async def edit_store_submission(
|
||||
# Currently we are not allowing user to update the agent associated with a submission
|
||||
# If we allow it in future, then we need a check here to verify the agent belongs to this user.
|
||||
|
||||
# Only allow editing of PENDING submissions
|
||||
if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING:
|
||||
# Check if we can edit this submission
|
||||
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited."
|
||||
"Cannot edit a rejected submission"
|
||||
)
|
||||
|
||||
# For APPROVED submissions, we need to create a new version
|
||||
if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||
# Create a new version for the existing listing
|
||||
return await create_store_version(
|
||||
user_id=user_id,
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
store_listing_id=current_version.storeListingId,
|
||||
name=name,
|
||||
video_url=video_url,
|
||||
agent_output_demo_url=agent_output_demo_url,
|
||||
image_urls=image_urls,
|
||||
description=description,
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
# For PENDING submissions, we can update the existing version
|
||||
# Update the existing version
|
||||
updated_version = await prisma.models.StoreListingVersion.prisma().update(
|
||||
where={"id": store_listing_version_id},
|
||||
data=prisma.types.StoreListingVersionUpdateInput(
|
||||
elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING:
|
||||
# Update the existing version
|
||||
updated_version = await prisma.models.StoreListingVersion.prisma().update(
|
||||
where={"id": store_listing_version_id},
|
||||
data=prisma.types.StoreListingVersionUpdateInput(
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
|
||||
)
|
||||
|
||||
if not updated_version:
|
||||
raise DatabaseError("Failed to update store listing version")
|
||||
return store_model.StoreSubmission(
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
sub_heading=sub_heading,
|
||||
slug=current_version.StoreListing.slug,
|
||||
description=description,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
),
|
||||
)
|
||||
image_urls=image_urls,
|
||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||
status=updated_version.submissionStatus,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
store_listing_version_id=updated_version.id,
|
||||
changes_summary=changes_summary,
|
||||
video_url=video_url,
|
||||
categories=categories,
|
||||
version=updated_version.version,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
|
||||
)
|
||||
|
||||
if not updated_version:
|
||||
raise DatabaseError("Failed to update store listing version")
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=current_version.StoreListing.id,
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
name=name,
|
||||
sub_heading=sub_heading,
|
||||
slug=current_version.StoreListing.slug,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||
status=updated_version.submissionStatus,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
store_listing_version_id=updated_version.id,
|
||||
changes_summary=changes_summary,
|
||||
video_url=video_url,
|
||||
categories=categories,
|
||||
version=updated_version.version,
|
||||
)
|
||||
else:
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
f"Cannot edit submission with status: {current_version.submissionStatus}"
|
||||
)
|
||||
|
||||
except (
|
||||
store_exceptions.SubmissionNotFoundError,
|
||||
@@ -1093,78 +1097,38 @@ async def create_store_version(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
# Check if there's already a PENDING submission for this agent (any version)
|
||||
existing_pending_submission = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_first(
|
||||
where=prisma.types.StoreListingVersionWhereInput(
|
||||
storeListingId=store_listing_id,
|
||||
agentGraphId=agent_id,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isDeleted=False,
|
||||
)
|
||||
# Get the latest version number
|
||||
latest_version = listing.Versions[0] if listing.Versions else None
|
||||
|
||||
next_version = (latest_version.version + 1) if latest_version else 1
|
||||
|
||||
# Create a new version for the existing listing
|
||||
new_version = await prisma.models.StoreListingVersion.prisma().create(
|
||||
data=prisma.types.StoreListingVersionCreateInput(
|
||||
version=next_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
storeListingId=store_listing_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Handle existing pending submission and create new one atomically
|
||||
async with transaction() as tx:
|
||||
# Get the latest version number first
|
||||
latest_listing = await prisma.models.StoreListing.prisma(tx).find_first(
|
||||
where=prisma.types.StoreListingWhereInput(
|
||||
id=store_listing_id, owningUserId=user_id
|
||||
),
|
||||
include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
|
||||
)
|
||||
|
||||
if not latest_listing:
|
||||
raise store_exceptions.ListingNotFoundError(
|
||||
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
|
||||
)
|
||||
|
||||
latest_version = (
|
||||
latest_listing.Versions[0] if latest_listing.Versions else None
|
||||
)
|
||||
next_version = (latest_version.version + 1) if latest_version else 1
|
||||
|
||||
# If there's an existing pending submission, delete it atomically before creating new one
|
||||
if existing_pending_submission:
|
||||
logger.info(
|
||||
f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate"
|
||||
)
|
||||
await prisma.models.StoreListingVersion.prisma(tx).delete(
|
||||
where={"id": existing_pending_submission.id}
|
||||
)
|
||||
logger.debug(
|
||||
f"Deleted existing pending submission {existing_pending_submission.id}"
|
||||
)
|
||||
|
||||
# Create a new version for the existing listing
|
||||
new_version = await prisma.models.StoreListingVersion.prisma(tx).create(
|
||||
data=prisma.types.StoreListingVersionCreateInput(
|
||||
version=next_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
storeListingId=store_listing_id,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Created new version for listing {store_listing_id} of agent {agent_id}"
|
||||
)
|
||||
# Return submission details
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=listing.id,
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
name=name,
|
||||
@@ -1744,12 +1708,15 @@ async def review_store_submission(
|
||||
|
||||
# Convert to Pydantic model for consistency
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
|
||||
agent_id=submission.agentGraphId,
|
||||
agent_version=submission.agentGraphVersion,
|
||||
name=submission.name,
|
||||
sub_heading=submission.subHeading,
|
||||
slug=(submission.StoreListing.slug if submission.StoreListing else ""),
|
||||
slug=(
|
||||
submission.StoreListing.slug
|
||||
if hasattr(submission, "storeListing") and submission.StoreListing
|
||||
else ""
|
||||
),
|
||||
description=submission.description,
|
||||
instructions=submission.instructions,
|
||||
image_urls=submission.imageUrls or [],
|
||||
@@ -1851,7 +1818,9 @@ async def get_admin_listings_with_versions(
|
||||
where = prisma.types.StoreListingWhereInput(**where_dict)
|
||||
include = prisma.types.StoreListingInclude(
|
||||
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
|
||||
order_by={"version": "desc"}
|
||||
order_by=prisma.types._StoreListingVersion_version_OrderByInput(
|
||||
version="desc"
|
||||
)
|
||||
),
|
||||
OwningUser=True,
|
||||
)
|
||||
@@ -1876,7 +1845,6 @@ async def get_admin_listings_with_versions(
|
||||
# If we have versions, turn them into StoreSubmission models
|
||||
for version in listing.Versions or []:
|
||||
version_model = store_model.StoreSubmission(
|
||||
listing_id=listing.id,
|
||||
agent_id=version.agentGraphId,
|
||||
agent_version=version.agentGraphVersion,
|
||||
name=version.name,
|
||||
|
||||
@@ -110,7 +110,6 @@ class Profile(pydantic.BaseModel):
|
||||
|
||||
|
||||
class StoreSubmission(pydantic.BaseModel):
|
||||
listing_id: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
name: str
|
||||
@@ -165,12 +164,8 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
||||
|
||||
|
||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
agent_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Agent ID cannot be empty"
|
||||
)
|
||||
agent_version: int = pydantic.Field(
|
||||
..., gt=0, description="Agent version must be greater than 0"
|
||||
)
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
slug: str
|
||||
name: str
|
||||
sub_heading: str
|
||||
|
||||
@@ -138,7 +138,6 @@ def test_creator_details():
|
||||
|
||||
def test_store_submission():
|
||||
submission = store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
@@ -160,7 +159,6 @@ def test_store_submissions_response():
|
||||
response = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
|
||||
@@ -521,7 +521,6 @@ def test_get_submissions_success(
|
||||
mocked_value = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="test-listing-id",
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
image_urls=["test.jpg"],
|
||||
|
||||
@@ -6,9 +6,6 @@ import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
from prisma.types import Serializable
|
||||
|
||||
from backend.sdk import (
|
||||
BaseWebhooksManager,
|
||||
@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
||||
# update webhook config
|
||||
await update_webhook(
|
||||
webhook.id,
|
||||
config=cast(
|
||||
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
|
||||
),
|
||||
config={"base_id": base_id, "cursor": response.cursor},
|
||||
)
|
||||
|
||||
event_type = "notification"
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
"""
|
||||
Shared helpers for Human-In-The-Loop (HITL) review functionality.
|
||||
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
||||
from backend.data.human_review import ReviewResult
|
||||
from backend.executor.manager import async_update_node_execution_status
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReviewDecision(BaseModel):
|
||||
"""Result of a review decision."""
|
||||
|
||||
should_proceed: bool
|
||||
message: str
|
||||
review_result: ReviewResult
|
||||
|
||||
|
||||
class HITLReviewHelper:
|
||||
"""Helper class for Human-In-The-Loop review operations."""
|
||||
|
||||
@staticmethod
|
||||
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
|
||||
"""Create or retrieve a human review from the database."""
|
||||
return await get_database_manager_async_client().get_or_create_human_review(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_node_execution_status(**kwargs) -> None:
|
||||
"""Update the execution status of a node."""
|
||||
await async_update_node_execution_status(
|
||||
db_client=get_database_manager_async_client(), **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_review_processed_status(
|
||||
node_exec_id: str, processed: bool
|
||||
) -> None:
|
||||
"""Update the processed status of a review."""
|
||||
return await get_database_manager_async_client().update_review_processed_status(
|
||||
node_exec_id, processed
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _handle_review_request(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Handle a review request for a block that requires human review.
|
||||
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
Returns:
|
||||
ReviewResult if review is complete, None if waiting for human input
|
||||
|
||||
Raises:
|
||||
Exception: If review creation or status update fails
|
||||
"""
|
||||
# Skip review if safe mode is disabled - return auto-approved result
|
||||
if not execution_context.safe_mode:
|
||||
logger.info(
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
|
||||
)
|
||||
return ReviewResult(
|
||||
data=input_data,
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="Auto-approved (safe mode disabled)",
|
||||
processed=True,
|
||||
node_exec_id=node_exec_id,
|
||||
)
|
||||
|
||||
result = await HITLReviewHelper.get_or_create_human_review(
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
input_data=input_data,
|
||||
message=f"Review required for {block_name} execution",
|
||||
editable=editable,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
logger.info(
|
||||
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
|
||||
)
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
return None # Signal that execution should pause
|
||||
|
||||
# Mark review as processed if not already done
|
||||
if not result.processed:
|
||||
await HITLReviewHelper.update_review_processed_status(
|
||||
node_exec_id=node_exec_id, processed=True
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def handle_review_decision(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewDecision]:
|
||||
"""
|
||||
Handle a review request and return the decision in a single call.
|
||||
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
Returns:
|
||||
ReviewDecision if review is complete (approved/rejected),
|
||||
None if execution should pause (awaiting review)
|
||||
"""
|
||||
review_result = await HITLReviewHelper._handle_review_request(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
)
|
||||
|
||||
if review_result is None:
|
||||
# Still awaiting review - return None to pause execution
|
||||
return None
|
||||
|
||||
# Review is complete, determine outcome
|
||||
should_proceed = review_result.status == ReviewStatus.APPROVED
|
||||
message = review_result.message or (
|
||||
"Execution approved by reviewer"
|
||||
if should_proceed
|
||||
else "Execution rejected by reviewer"
|
||||
)
|
||||
|
||||
return ReviewDecision(
|
||||
should_proceed=should_proceed, message=message, review_result=review_result
|
||||
)
|
||||
@@ -3,7 +3,6 @@ from typing import Any
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.blocks.helpers.review import HITLReviewHelper
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -12,9 +11,11 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
||||
from backend.data.human_review import ReviewResult
|
||||
from backend.data.model import SchemaField
|
||||
from backend.executor.manager import async_update_node_execution_status
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,26 +72,32 @@ class HumanInTheLoopBlock(Block):
|
||||
("approved_data", {"name": "John Doe", "age": 30}),
|
||||
],
|
||||
test_mock={
|
||||
"handle_review_decision": lambda **kwargs: type(
|
||||
"ReviewDecision",
|
||||
(),
|
||||
{
|
||||
"should_proceed": True,
|
||||
"message": "Test approval message",
|
||||
"review_result": ReviewResult(
|
||||
data={"name": "John Doe", "age": 30},
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="",
|
||||
processed=False,
|
||||
node_exec_id="test-node-exec-id",
|
||||
),
|
||||
},
|
||||
)(),
|
||||
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
|
||||
data={"name": "John Doe", "age": 30},
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="",
|
||||
processed=False,
|
||||
node_exec_id="test-node-exec-id",
|
||||
),
|
||||
"update_node_execution_status": lambda *_args, **_kwargs: None,
|
||||
"update_review_processed_status": lambda *_args, **_kwargs: None,
|
||||
},
|
||||
)
|
||||
|
||||
async def handle_review_decision(self, **kwargs):
|
||||
return await HITLReviewHelper.handle_review_decision(**kwargs)
|
||||
async def get_or_create_human_review(self, **kwargs):
|
||||
return await get_database_manager_async_client().get_or_create_human_review(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def update_node_execution_status(self, **kwargs):
|
||||
return await async_update_node_execution_status(
|
||||
db_client=get_database_manager_async_client(), **kwargs
|
||||
)
|
||||
|
||||
async def update_review_processed_status(self, node_exec_id: str, processed: bool):
|
||||
return await get_database_manager_async_client().update_review_processed_status(
|
||||
node_exec_id, processed
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -102,7 +109,7 @@ class HumanInTheLoopBlock(Block):
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
**_kwargs,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not execution_context.safe_mode:
|
||||
logger.info(
|
||||
@@ -112,28 +119,48 @@ class HumanInTheLoopBlock(Block):
|
||||
yield "review_message", "Auto-approved (safe mode disabled)"
|
||||
return
|
||||
|
||||
decision = await self.handle_review_decision(
|
||||
input_data=input_data.data,
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=input_data.editable,
|
||||
)
|
||||
try:
|
||||
result = await self.get_or_create_human_review(
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
input_data=input_data.data,
|
||||
message=input_data.name,
|
||||
editable=input_data.editable,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
if decision is None:
|
||||
return
|
||||
if result is None:
|
||||
logger.info(
|
||||
f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
|
||||
)
|
||||
try:
|
||||
await self.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
status = decision.review_result.status
|
||||
if status == ReviewStatus.APPROVED:
|
||||
yield "approved_data", decision.review_result.data
|
||||
elif status == ReviewStatus.REJECTED:
|
||||
yield "rejected_data", decision.review_result.data
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected review status: {status}")
|
||||
if not result.processed:
|
||||
await self.update_review_processed_status(
|
||||
node_exec_id=node_exec_id, processed=True
|
||||
)
|
||||
|
||||
if decision.message:
|
||||
yield "review_message", decision.message
|
||||
if result.status == ReviewStatus.APPROVED:
|
||||
yield "approved_data", result.data
|
||||
if result.message:
|
||||
yield "review_message", result.message
|
||||
|
||||
elif result.status == ReviewStatus.REJECTED:
|
||||
yield "rejected_data", result.data
|
||||
if result.message:
|
||||
yield "review_message", result.message
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -391,12 +391,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
"""
|
||||
block = sink_node.block
|
||||
|
||||
# Use custom name from node metadata if set, otherwise fall back to block.name
|
||||
custom_name = sink_node.metadata.get("customized_name")
|
||||
tool_name = custom_name if custom_name else block.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"name": SmartDecisionMakerBlock.cleanup(block.name),
|
||||
"description": block.description,
|
||||
}
|
||||
sink_block_input_schema = block.input_schema
|
||||
@@ -493,24 +489,14 @@ class SmartDecisionMakerBlock(Block):
|
||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||
)
|
||||
|
||||
# Use custom name from node metadata if set, otherwise fall back to graph name
|
||||
custom_name = sink_node.metadata.get("customized_name")
|
||||
tool_name = custom_name if custom_name else sink_graph_meta.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
|
||||
"description": sink_graph_meta.description,
|
||||
}
|
||||
|
||||
properties = {}
|
||||
field_mapping = {}
|
||||
|
||||
for link in links:
|
||||
field_name = link.sink_name
|
||||
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
field_mapping[clean_field_name] = field_name
|
||||
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
||||
link.sink_name, {}
|
||||
@@ -520,7 +506,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
if "description" in sink_block_properties
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[clean_field_name] = {
|
||||
properties[link.sink_name] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
"default": json.dumps(sink_block_properties.get("default", None)),
|
||||
@@ -533,7 +519,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
tool_function["_field_mapping"] = field_mapping
|
||||
# Store node info for later use in output processing
|
||||
tool_function["_sink_node_id"] = sink_node.id
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
@@ -989,28 +975,10 @@ class SmartDecisionMakerBlock(Block):
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
execution_processor: "ExecutionProcessor",
|
||||
nodes_to_skip: set[str] | None = None,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
tool_functions = await self._create_tool_node_signatures(node_id)
|
||||
original_tool_count = len(tool_functions)
|
||||
|
||||
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
|
||||
if nodes_to_skip:
|
||||
tool_functions = [
|
||||
tf
|
||||
for tf in tool_functions
|
||||
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
|
||||
]
|
||||
|
||||
# Only raise error if we had tools but they were all filtered out
|
||||
if original_tool_count > 0 and not tool_functions:
|
||||
raise ValueError(
|
||||
"No available tools to execute - all downstream nodes are unavailable "
|
||||
"(possibly due to missing optional credentials)"
|
||||
)
|
||||
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
conversation_history = input_data.conversation_history or []
|
||||
@@ -1161,9 +1129,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
arg_value = tool_args.get(clean_arg_name)
|
||||
|
||||
# Use original_field_name directly (not sanitized) to match link sink_name
|
||||
# The field_mapping already translates from LLM's cleaned names to original names
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"
|
||||
sanitized_arg_name = self.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
|
||||
@@ -1057,153 +1057,3 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
) # Should yield individual tool parameters
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert "conversations" in outputs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the customized name (cleaned up)
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version
|
||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the block's default name
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "storevalueblock" # Default block name cleaned
|
||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-agent-node-id"
|
||||
mock_node.metadata = {"customized_name": "My Custom Agent"}
|
||||
mock_node.input_default = {
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_version": 1,
|
||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
||||
}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "test_input"
|
||||
|
||||
# Mock the database client
|
||||
mock_graph_meta = MagicMock()
|
||||
mock_graph_meta.name = "Original Agent Name"
|
||||
mock_graph_meta.description = "Agent description"
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the customized name (cleaned up)
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "my_custom_agent" # Cleaned version
|
||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
"""Test that agent node falls back to graph name when no customized_name."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-agent-node-id"
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.input_default = {
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_version": 1,
|
||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
||||
}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "test_input"
|
||||
|
||||
# Mock the database client
|
||||
mock_graph_meta = MagicMock()
|
||||
mock_graph_meta.name = "Original Agent Name"
|
||||
mock_graph_meta.description = "Agent description"
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the graph's default name
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "original_agent_name" # Graph name cleaned
|
||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
@@ -15,7 +15,6 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
mock_node.block = CreateDictionaryBlock()
|
||||
mock_node.block_id = CreateDictionaryBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic dictionary fields
|
||||
mock_links = [
|
||||
@@ -78,7 +77,6 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
mock_node.block = AddToListBlock()
|
||||
mock_node.block_id = AddToListBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic list fields
|
||||
mock_links = [
|
||||
|
||||
@@ -44,7 +44,6 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
mock_node.block = CreateDictionaryBlock()
|
||||
mock_node.block_id = CreateDictionaryBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
|
||||
mock_links = [
|
||||
@@ -107,7 +106,6 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
mock_node.block = AddToListBlock()
|
||||
mock_node.block_id = AddToListBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic list fields
|
||||
mock_links = [
|
||||
@@ -161,7 +159,6 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
mock_node.block = MatchTextPatternBlock()
|
||||
mock_node.block_id = MatchTextPatternBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic object fields
|
||||
mock_links = [
|
||||
@@ -211,13 +208,11 @@ async def test_create_tool_node_signatures():
|
||||
mock_dict_node.block = CreateDictionaryBlock()
|
||||
mock_dict_node.block_id = CreateDictionaryBlock().id
|
||||
mock_dict_node.input_default = {}
|
||||
mock_dict_node.metadata = {}
|
||||
|
||||
mock_list_node = Mock()
|
||||
mock_list_node.block = AddToListBlock()
|
||||
mock_list_node.block_id = AddToListBlock().id
|
||||
mock_list_node.input_default = {}
|
||||
mock_list_node.metadata = {}
|
||||
|
||||
# Mock links with dynamic fields
|
||||
dict_link1 = Mock(
|
||||
@@ -428,7 +423,6 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "A test block"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Mock the get_field_schema to return a proper schema for regular fields
|
||||
def get_field_schema(field_name):
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .blog import WordPressCreatePostBlock, WordPressGetAllPostsBlock
|
||||
from .blog import WordPressCreatePostBlock
|
||||
|
||||
__all__ = ["WordPressCreatePostBlock", "WordPressGetAllPostsBlock"]
|
||||
__all__ = ["WordPressCreatePostBlock"]
|
||||
|
||||
@@ -161,7 +161,7 @@ async def oauth_exchange_code_for_tokens(
|
||||
grant_type="authorization_code",
|
||||
).model_dump(exclude_none=True)
|
||||
|
||||
response = await Requests(raise_for_status=False).post(
|
||||
response = await Requests().post(
|
||||
f"{WORDPRESS_BASE_URL}oauth2/token",
|
||||
headers=headers,
|
||||
data=data,
|
||||
@@ -205,7 +205,7 @@ async def oauth_refresh_tokens(
|
||||
grant_type="refresh_token",
|
||||
).model_dump(exclude_none=True)
|
||||
|
||||
response = await Requests(raise_for_status=False).post(
|
||||
response = await Requests().post(
|
||||
f"{WORDPRESS_BASE_URL}oauth2/token",
|
||||
headers=headers,
|
||||
data=data,
|
||||
@@ -252,7 +252,7 @@ async def validate_token(
|
||||
"token": token,
|
||||
}
|
||||
|
||||
response = await Requests(raise_for_status=False).get(
|
||||
response = await Requests().get(
|
||||
f"{WORDPRESS_BASE_URL}oauth2/token-info",
|
||||
params=params,
|
||||
)
|
||||
@@ -296,7 +296,7 @@ async def make_api_request(
|
||||
|
||||
url = f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}"
|
||||
|
||||
request_method = getattr(Requests(raise_for_status=False), method.lower())
|
||||
request_method = getattr(Requests(), method.lower())
|
||||
response = await request_method(
|
||||
url,
|
||||
headers=headers,
|
||||
@@ -476,7 +476,6 @@ async def create_post(
|
||||
data["tags"] = ",".join(str(t) for t in data["tags"])
|
||||
|
||||
# Make the API request
|
||||
site = normalize_site(site)
|
||||
endpoint = f"/rest/v1.1/sites/{site}/posts/new"
|
||||
|
||||
headers = {
|
||||
@@ -484,7 +483,7 @@ async def create_post(
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
response = await Requests(raise_for_status=False).post(
|
||||
response = await Requests().post(
|
||||
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
|
||||
headers=headers,
|
||||
data=data,
|
||||
@@ -500,132 +499,3 @@ async def create_post(
|
||||
)
|
||||
error_message = error_data.get("message", response.text)
|
||||
raise ValueError(f"Failed to create post: {response.status} - {error_message}")
|
||||
|
||||
|
||||
class Post(BaseModel):
|
||||
"""Response model for individual posts in a posts list response.
|
||||
|
||||
This is a simplified version compared to PostResponse, as the list endpoint
|
||||
returns less detailed information than the create/get single post endpoints.
|
||||
"""
|
||||
|
||||
ID: int
|
||||
site_ID: int
|
||||
author: PostAuthor
|
||||
date: datetime
|
||||
modified: datetime
|
||||
title: str
|
||||
URL: str
|
||||
short_URL: str
|
||||
content: str | None = None
|
||||
excerpt: str | None = None
|
||||
slug: str
|
||||
guid: str
|
||||
status: str
|
||||
sticky: bool
|
||||
password: str | None = ""
|
||||
parent: Union[Dict[str, Any], bool, None] = None
|
||||
type: str
|
||||
discussion: Dict[str, Union[str, bool, int]] | None = None
|
||||
likes_enabled: bool | None = None
|
||||
sharing_enabled: bool | None = None
|
||||
like_count: int | None = None
|
||||
i_like: bool | None = None
|
||||
is_reblogged: bool | None = None
|
||||
is_following: bool | None = None
|
||||
global_ID: str | None = None
|
||||
featured_image: str | None = None
|
||||
post_thumbnail: Dict[str, Any] | None = None
|
||||
format: str | None = None
|
||||
geo: Union[Dict[str, Any], bool, None] = None
|
||||
menu_order: int | None = None
|
||||
page_template: str | None = None
|
||||
publicize_URLs: List[str] | None = None
|
||||
terms: Dict[str, Dict[str, Any]] | None = None
|
||||
tags: Dict[str, Dict[str, Any]] | None = None
|
||||
categories: Dict[str, Dict[str, Any]] | None = None
|
||||
attachments: Dict[str, Dict[str, Any]] | None = None
|
||||
attachment_count: int | None = None
|
||||
metadata: List[Dict[str, Any]] | None = None
|
||||
meta: Dict[str, Any] | None = None
|
||||
capabilities: Dict[str, bool] | None = None
|
||||
revisions: List[int] | None = None
|
||||
other_URLs: Dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PostsResponse(BaseModel):
|
||||
"""Response model for WordPress posts list."""
|
||||
|
||||
found: int
|
||||
posts: List[Post]
|
||||
meta: Dict[str, Any]
|
||||
|
||||
|
||||
def normalize_site(site: str) -> str:
|
||||
"""
|
||||
Normalize a site identifier by stripping protocol and trailing slashes.
|
||||
|
||||
Args:
|
||||
site: Site URL, domain, or ID (e.g., "https://myblog.wordpress.com/", "myblog.wordpress.com", "123456789")
|
||||
|
||||
Returns:
|
||||
Normalized site identifier (domain or ID only)
|
||||
"""
|
||||
site = site.strip()
|
||||
if site.startswith("https://"):
|
||||
site = site[8:]
|
||||
elif site.startswith("http://"):
|
||||
site = site[7:]
|
||||
return site.rstrip("/")
|
||||
|
||||
|
||||
async def get_posts(
|
||||
credentials: Credentials,
|
||||
site: str,
|
||||
status: PostStatus | None = None,
|
||||
number: int = 100,
|
||||
offset: int = 0,
|
||||
) -> PostsResponse:
|
||||
"""
|
||||
Get posts from a WordPress site.
|
||||
|
||||
Args:
|
||||
credentials: OAuth credentials
|
||||
site: Site ID or domain (e.g., "myblog.wordpress.com" or "123456789")
|
||||
status: Filter by post status using PostStatus enum, or None for all
|
||||
number: Number of posts to retrieve (max 100)
|
||||
offset: Number of posts to skip (for pagination)
|
||||
|
||||
Returns:
|
||||
PostsResponse with the list of posts
|
||||
"""
|
||||
site = normalize_site(site)
|
||||
endpoint = f"/rest/v1.1/sites/{site}/posts"
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.auth_header(),
|
||||
}
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"number": max(1, min(number, 100)), # 1–100 posts per request
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
if status:
|
||||
params["status"] = status.value
|
||||
response = await Requests(raise_for_status=False).get(
|
||||
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
|
||||
headers=headers,
|
||||
params=params,
|
||||
)
|
||||
|
||||
if response.ok:
|
||||
return PostsResponse.model_validate(response.json())
|
||||
|
||||
error_data = (
|
||||
response.json()
|
||||
if response.headers.get("content-type", "").startswith("application/json")
|
||||
else {}
|
||||
)
|
||||
error_message = error_data.get("message", response.text)
|
||||
raise ValueError(f"Failed to get posts: {response.status} - {error_message}")
|
||||
|
||||
@@ -9,15 +9,7 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import (
|
||||
CreatePostRequest,
|
||||
Post,
|
||||
PostResponse,
|
||||
PostsResponse,
|
||||
PostStatus,
|
||||
create_post,
|
||||
get_posts,
|
||||
)
|
||||
from ._api import CreatePostRequest, PostResponse, PostStatus, create_post
|
||||
from ._config import wordpress
|
||||
|
||||
|
||||
@@ -57,15 +49,8 @@ class WordPressCreatePostBlock(Block):
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="URLs of images to sideload and attach to the post", default=[]
|
||||
)
|
||||
publish_as_draft: bool = SchemaField(
|
||||
description="If True, publishes the post as a draft. If False, publishes it publicly.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
site: str = SchemaField(
|
||||
description="The site ID or domain (pass-through for chaining with other blocks)"
|
||||
)
|
||||
post_id: int = SchemaField(description="The ID of the created post")
|
||||
post_url: str = SchemaField(description="The full URL of the created post")
|
||||
short_url: str = SchemaField(description="The shortened wp.me URL")
|
||||
@@ -93,9 +78,7 @@ class WordPressCreatePostBlock(Block):
|
||||
tags=input_data.tags,
|
||||
featured_image=input_data.featured_image,
|
||||
media_urls=input_data.media_urls,
|
||||
status=(
|
||||
PostStatus.DRAFT if input_data.publish_as_draft else PostStatus.PUBLISH
|
||||
),
|
||||
status=PostStatus.PUBLISH,
|
||||
)
|
||||
|
||||
post_response: PostResponse = await create_post(
|
||||
@@ -104,69 +87,7 @@ class WordPressCreatePostBlock(Block):
|
||||
post_data=post_request,
|
||||
)
|
||||
|
||||
yield "site", input_data.site
|
||||
yield "post_id", post_response.ID
|
||||
yield "post_url", post_response.URL
|
||||
yield "short_url", post_response.short_URL
|
||||
yield "post_data", post_response.model_dump()
|
||||
|
||||
|
||||
class WordPressGetAllPostsBlock(Block):
|
||||
"""
|
||||
Fetches all posts from a WordPress.com site or Jetpack-enabled site.
|
||||
Supports filtering by status and pagination.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = wordpress.credentials_field()
|
||||
site: str = SchemaField(
|
||||
description="Site ID or domain (e.g., 'myblog.wordpress.com' or '123456789')"
|
||||
)
|
||||
status: PostStatus | None = SchemaField(
|
||||
description="Filter by post status, or None for all",
|
||||
default=None,
|
||||
)
|
||||
number: int = SchemaField(
|
||||
description="Number of posts to retrieve (max 100 per request)", default=20
|
||||
)
|
||||
offset: int = SchemaField(
|
||||
description="Number of posts to skip (for pagination)", default=0
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
site: str = SchemaField(
|
||||
description="The site ID or domain (pass-through for chaining with other blocks)"
|
||||
)
|
||||
found: int = SchemaField(description="Total number of posts found")
|
||||
posts: list[Post] = SchemaField(
|
||||
description="List of post objects with their details"
|
||||
)
|
||||
post: Post = SchemaField(
|
||||
description="Individual post object (yielded for each post)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="97728fa7-7f6f-4789-ba0c-f2c114119536",
|
||||
description="Fetch all posts from WordPress.com or Jetpack sites",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
posts_response: PostsResponse = await get_posts(
|
||||
credentials=credentials,
|
||||
site=input_data.site,
|
||||
status=input_data.status,
|
||||
number=input_data.number,
|
||||
offset=input_data.offset,
|
||||
)
|
||||
|
||||
yield "site", input_data.site
|
||||
yield "found", posts_response.found
|
||||
yield "posts", posts_response.posts
|
||||
for post in posts_response.posts:
|
||||
yield "post", post
|
||||
|
||||
@@ -50,8 +50,6 @@ from .model import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
from .graph import Link
|
||||
|
||||
app_config = Config()
|
||||
@@ -474,7 +472,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.requires_human_review: bool = False
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -617,77 +614,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
block_id=self.id,
|
||||
) from ex
|
||||
|
||||
async def is_block_exec_need_review(
|
||||
self,
|
||||
input_data: BlockInput,
|
||||
*,
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: "ExecutionContext",
|
||||
**kwargs,
|
||||
) -> tuple[bool, BlockInput]:
|
||||
"""
|
||||
Check if this block execution needs human review and handle the review process.
|
||||
|
||||
Returns:
|
||||
Tuple of (should_pause, input_data_to_use)
|
||||
- should_pause: True if execution should be paused for review
|
||||
- input_data_to_use: The input data to use (may be modified by reviewer)
|
||||
"""
|
||||
# Skip review if not required or safe mode is disabled
|
||||
if not self.requires_human_review or not execution_context.safe_mode:
|
||||
return False, input_data
|
||||
|
||||
from backend.blocks.helpers.review import HITLReviewHelper
|
||||
|
||||
# Handle the review request and get decision
|
||||
decision = await HITLReviewHelper.handle_review_decision(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
)
|
||||
|
||||
if decision is None:
|
||||
# We're awaiting review - pause execution
|
||||
return True, input_data
|
||||
|
||||
if not decision.should_proceed:
|
||||
# Review was rejected, raise an error to stop execution
|
||||
raise BlockExecutionError(
|
||||
message=f"Block execution rejected by reviewer: {decision.message}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Review was approved - use the potentially modified data
|
||||
# ReviewResult.data must be a dict for block inputs
|
||||
reviewed_data = decision.review_result.data
|
||||
if not isinstance(reviewed_data, dict):
|
||||
raise BlockExecutionError(
|
||||
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
return False, reviewed_data
|
||||
|
||||
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
# Check for review requirement and get potentially modified input data
|
||||
should_pause, input_data = await self.is_block_exec_need_review(
|
||||
input_data, **kwargs
|
||||
)
|
||||
if should_pause:
|
||||
return
|
||||
|
||||
# Validate the input data (original or reviewer-modified) once
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
@@ -695,7 +622,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Use the validated input data
|
||||
async for output_name, output_data in self.run(
|
||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||
**kwargs,
|
||||
|
||||
@@ -383,7 +383,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
self,
|
||||
execution_context: ExecutionContext,
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
@@ -391,7 +390,6 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip or set(),
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
@@ -1147,8 +1145,6 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||
nodes_to_skip: set[str] = Field(default_factory=set)
|
||||
"""Node IDs that should be skipped due to optional credentials not being configured."""
|
||||
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
|
||||
|
||||
|
||||
|
||||
@@ -94,15 +94,6 @@ class Node(BaseDbModel):
|
||||
input_links: list[Link] = []
|
||||
output_links: list[Link] = []
|
||||
|
||||
@property
|
||||
def credentials_optional(self) -> bool:
|
||||
"""
|
||||
Whether credentials are optional for this node.
|
||||
When True and credentials are not configured, the node will be skipped
|
||||
during execution rather than causing a validation error.
|
||||
"""
|
||||
return self.metadata.get("credentials_optional", False)
|
||||
|
||||
@property
|
||||
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
|
||||
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
|
||||
@@ -244,10 +235,7 @@ class BaseGraph(BaseDbModel):
|
||||
return any(
|
||||
node.block_id
|
||||
for node in self.nodes
|
||||
if (
|
||||
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
or node.block.requires_human_review
|
||||
)
|
||||
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -338,35 +326,7 @@ class Graph(BaseGraph):
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
schema = self._credentials_input_schema.jsonschema()
|
||||
|
||||
# Determine which credential fields are required based on credentials_optional metadata
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
required_fields = []
|
||||
|
||||
# Build a map of node_id -> node for quick lookup
|
||||
all_nodes = {node.id: node for node in self.nodes}
|
||||
for sub_graph in self.sub_graphs:
|
||||
for node in sub_graph.nodes:
|
||||
all_nodes[node.id] = node
|
||||
|
||||
for field_key, (
|
||||
_field_info,
|
||||
node_field_pairs,
|
||||
) in graph_credentials_inputs.items():
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
is_required = False
|
||||
for node_id, _field_name in node_field_pairs:
|
||||
node = all_nodes.get(node_id)
|
||||
if node and not node.credentials_optional:
|
||||
is_required = True
|
||||
break
|
||||
|
||||
if is_required:
|
||||
required_fields.append(field_key)
|
||||
|
||||
schema["required"] = required_fields
|
||||
return schema
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
|
||||
@@ -396,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
)
|
||||
assert got_graph is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_node_credentials_optional_default():
|
||||
"""Test that credentials_optional defaults to False when not set in metadata."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_true():
|
||||
"""Test that credentials_optional returns True when explicitly set."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": True},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
|
||||
|
||||
def test_node_credentials_optional_false():
|
||||
"""Test that credentials_optional returns False when explicitly set to False."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": False},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_with_other_metadata():
|
||||
"""Test that credentials_optional works correctly with other metadata present."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={
|
||||
"position": {"x": 100, "y": 200},
|
||||
"customized_name": "My Custom Node",
|
||||
"credentials_optional": True,
|
||||
},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||
assert node.metadata["customized_name"] == "My Custom Node"
|
||||
|
||||
@@ -178,7 +178,6 @@ async def execute_node(
|
||||
execution_processor: "ExecutionProcessor",
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -246,7 +245,6 @@ async def execute_node(
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
"execution_processor": execution_processor,
|
||||
"nodes_to_skip": nodes_to_skip or set(),
|
||||
}
|
||||
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
@@ -544,7 +542,6 @@ class ExecutionProcessor:
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[NodesInputMasks],
|
||||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
@@ -567,7 +564,6 @@ class ExecutionProcessor:
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
)
|
||||
if isinstance(status, BaseException):
|
||||
raise status
|
||||
@@ -613,7 +609,6 @@ class ExecutionProcessor:
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
) -> ExecutionStatus:
|
||||
status = ExecutionStatus.RUNNING
|
||||
|
||||
@@ -650,7 +645,6 @@ class ExecutionProcessor:
|
||||
execution_processor=self,
|
||||
execution_stats=stats,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
):
|
||||
await persist_output(output_name, output_data)
|
||||
|
||||
@@ -962,21 +956,6 @@ class ExecutionProcessor:
|
||||
|
||||
queued_node_exec = execution_queue.get()
|
||||
|
||||
# Check if this node should be skipped due to optional credentials
|
||||
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
|
||||
log_metadata.info(
|
||||
f"Skipping node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id} - optional credentials not configured"
|
||||
)
|
||||
# Mark the node as completed without executing
|
||||
# No outputs will be produced, so downstream nodes won't trigger
|
||||
update_node_execution_status(
|
||||
db_client=db_client,
|
||||
exec_id=queued_node_exec.node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
)
|
||||
continue
|
||||
|
||||
log_metadata.debug(
|
||||
f"Dispatching node execution {queued_node_exec.node_exec_id} "
|
||||
f"for node {queued_node_exec.node_id}",
|
||||
@@ -1037,7 +1016,6 @@ class ExecutionProcessor:
|
||||
execution_stats,
|
||||
execution_stats_lock,
|
||||
),
|
||||
nodes_to_skip=graph_exec.nodes_to_skip,
|
||||
),
|
||||
self.node_execution_loop,
|
||||
)
|
||||
|
||||
@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[dict[str, dict[str, str]], set[str]]:
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Checks all credentials for all nodes of the graph and returns structured errors
|
||||
and a set of nodes that should be skipped due to optional missing credentials.
|
||||
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
|
||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
||||
]
|
||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
|
||||
"""
|
||||
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
|
||||
nodes_to_skip: set[str] = set()
|
||||
|
||||
for node in graph.nodes:
|
||||
block = node.block
|
||||
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
|
||||
if not credentials_fields:
|
||||
continue
|
||||
|
||||
# Track if any credential field is missing for this node
|
||||
has_missing_credentials = False
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
try:
|
||||
# Check nodes_input_masks first, then input_default
|
||||
field_value = None
|
||||
if (
|
||||
nodes_input_masks
|
||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||
and field_name in node_input_mask
|
||||
):
|
||||
field_value = node_input_mask[field_name]
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node_input_mask[field_name]
|
||||
)
|
||||
elif field_name in node.input_default:
|
||||
# For optional credentials, don't use input_default - treat as missing
|
||||
# This prevents stale credential IDs from failing validation
|
||||
if node.credentials_optional:
|
||||
field_value = None
|
||||
else:
|
||||
field_value = node.input_default[field_name]
|
||||
|
||||
# Check if credentials are missing (None, empty, or not present)
|
||||
if field_value is None or (
|
||||
isinstance(field_value, dict) and not field_value.get("id")
|
||||
):
|
||||
has_missing_credentials = True
|
||||
# If node has credentials_optional flag, mark for skipping instead of error
|
||||
if node.credentials_optional:
|
||||
continue # Don't add error, will be marked for skip after loop
|
||||
else:
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "These credentials are required"
|
||||
continue
|
||||
|
||||
credentials_meta = credentials_meta_type.model_validate(field_value)
|
||||
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
else:
|
||||
# Missing credentials
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "These credentials are required"
|
||||
continue
|
||||
except ValidationError as e:
|
||||
# Validation error means credentials were provided but invalid
|
||||
# This should always be an error, even if optional
|
||||
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
|
||||
continue
|
||||
|
||||
@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any errors fetching credentials
|
||||
# If credentials were explicitly configured but unavailable, it's an error
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Credentials not available: {e}"
|
||||
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
|
||||
] = "Invalid credentials: type/provider mismatch"
|
||||
continue
|
||||
|
||||
# If node has optional credentials and any are missing, mark for skipping
|
||||
# But only if there are no other errors for this node
|
||||
if (
|
||||
has_missing_credentials
|
||||
and node.credentials_optional
|
||||
and node.id not in credential_errors
|
||||
):
|
||||
nodes_to_skip.add(node.id)
|
||||
logger.info(
|
||||
f"Node #{node.id} will be skipped: optional credentials not configured"
|
||||
)
|
||||
|
||||
return credential_errors, nodes_to_skip
|
||||
return credential_errors
|
||||
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
|
||||
) -> Mapping[str, Mapping[str, str]]:
|
||||
"""
|
||||
Validate graph including credentials and return structured errors per node,
|
||||
along with a set of nodes that should be skipped due to optional missing credentials.
|
||||
Validate graph including credentials and return structured errors per node.
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
|
||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
||||
]
|
||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node
|
||||
"""
|
||||
# Get input validation errors
|
||||
node_input_errors = GraphModel.validate_graph_get_errors(
|
||||
graph, for_run=True, nodes_input_masks=nodes_input_masks
|
||||
)
|
||||
|
||||
# Get credential input/availability/validation errors and nodes to skip
|
||||
node_credential_input_errors, nodes_to_skip = (
|
||||
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
|
||||
# Get credential input/availability/validation errors
|
||||
node_credential_input_errors = await _validate_node_input_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
|
||||
# Merge credential errors with structural errors
|
||||
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
|
||||
node_input_errors[node_id] = {}
|
||||
node_input_errors[node_id].update(field_errors)
|
||||
|
||||
return node_input_errors, nodes_to_skip
|
||||
return node_input_errors
|
||||
|
||||
|
||||
async def _construct_starting_node_execution_input(
|
||||
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
This function checks the graph for starting nodes, validates the input data
|
||||
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
|
||||
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
|
||||
|
||||
Returns:
|
||||
tuple[
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
|
||||
and the corresponding input data for that node.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured)
|
||||
]
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
# Use new validation function that includes credentials
|
||||
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
validation_errors = await validate_graph_with_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
n_error_nodes = len(validation_errors)
|
||||
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
return nodes_input, nodes_to_skip
|
||||
return nodes_input
|
||||
|
||||
|
||||
async def validate_and_construct_node_execution_input(
|
||||
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
is_sub_graph: bool = False,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured).
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
|
||||
starting_nodes_input, nodes_to_skip = (
|
||||
await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
starting_nodes_input = await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
|
||||
return graph, starting_nodes_input, nodes_input_masks
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
@@ -826,9 +779,6 @@ async def add_graph_execution(
|
||||
|
||||
# Use existing execution's compiled input masks
|
||||
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
|
||||
# For resumed executions, nodes_to_skip was already determined at creation time
|
||||
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
|
||||
nodes_to_skip: set[str] = set()
|
||||
|
||||
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
|
||||
else:
|
||||
@@ -837,7 +787,7 @@ async def add_graph_execution(
|
||||
)
|
||||
|
||||
# Create new execution
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks = (
|
||||
await validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
@@ -886,7 +836,6 @@ async def add_graph_execution(
|
||||
try:
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||
|
||||
@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
||||
nodes_to_skip: set[str] = set()
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip,
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
# Both executions should succeed (though they create different objects)
|
||||
assert result1 == mock_graph_exec
|
||||
assert result2 == mock_graph_exec_2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns nodes_to_skip set
|
||||
for nodes with credentials_optional=True and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=True
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-optional-creds"
|
||||
mock_node.credentials_optional = True
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in nodes_to_skip, not in errors
|
||||
assert mock_node.id in nodes_to_skip
|
||||
assert mock_node.id not in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_required_missing_creds_error(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns errors
|
||||
for nodes with credentials_optional=False and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=False (required)
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-required-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in errors, not in nodes_to_skip
|
||||
assert mock_node.id in errors
|
||||
assert "credentials" in errors[mock_node.id]
|
||||
assert "required" in errors[mock_node.id]["credentials"].lower()
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that validate_graph_with_credentials returns nodes_to_skip set
|
||||
from _validate_node_input_credentials.
|
||||
"""
|
||||
from backend.executor.utils import validate_graph_with_credentials
|
||||
|
||||
# Mock _validate_node_input_credentials to return specific values
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils._validate_node_input_credentials"
|
||||
)
|
||||
expected_errors = {"node1": {"field": "error"}}
|
||||
expected_nodes_to_skip = {"node2", "node3"}
|
||||
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
|
||||
|
||||
# Mock GraphModel with validate_graph_get_errors method
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.validate_graph_get_errors.return_value = {}
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip is passed through
|
||||
assert nodes_to_skip == expected_nodes_to_skip
|
||||
assert "node1" in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
"""
|
||||
Test that add_graph_execution properly passes nodes_to_skip
|
||||
to the graph execution entry.
|
||||
"""
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
# Mock data
|
||||
graph_id = "test-graph-id"
|
||||
user_id = "test-user-id"
|
||||
inputs = {"test_input": "test_value"}
|
||||
graph_version = 1
|
||||
|
||||
# Mock the graph object
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.version = graph_version
|
||||
|
||||
# Starting nodes and masks
|
||||
starting_nodes_input = [("node1", {"input1": "value1"})]
|
||||
compiled_nodes_input_masks = {}
|
||||
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
|
||||
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = []
|
||||
|
||||
# Track what's passed to to_graph_execution_entry
|
||||
captured_kwargs = {}
|
||||
|
||||
def capture_to_entry(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mocker.MagicMock()
|
||||
|
||||
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
|
||||
|
||||
# Setup mocks
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
||||
)
|
||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_udb = mocker.patch("backend.executor.utils.user_db")
|
||||
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
|
||||
# Setup returns - include nodes_to_skip in the tuple
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip, # This should be passed through
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
||||
|
||||
mock_user = mocker.MagicMock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_settings = mocker.MagicMock()
|
||||
mock_settings.human_in_the_loop_safe_mode = True
|
||||
|
||||
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
|
||||
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
|
||||
mock_get_queue.return_value = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
|
||||
|
||||
# Call the function
|
||||
await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
@@ -8,7 +8,6 @@ from .discord import DiscordOAuthHandler
|
||||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
from .reddit import RedditOAuthHandler
|
||||
from .twitter import TwitterOAuthHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -21,7 +20,6 @@ _ORIGINAL_HANDLERS = [
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
RedditOAuthHandler,
|
||||
TwitterOAuthHandler,
|
||||
TodoistOAuthHandler,
|
||||
]
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class RedditOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
Reddit OAuth 2.0 handler.
|
||||
|
||||
Based on the documentation at:
|
||||
- https://github.com/reddit-archive/reddit/wiki/OAuth2
|
||||
|
||||
Notes:
|
||||
- Reddit requires `duration=permanent` to get refresh tokens
|
||||
- Access tokens expire after 1 hour (3600 seconds)
|
||||
- Reddit requires HTTP Basic Auth for token requests
|
||||
- Reddit requires a unique User-Agent header
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName.REDDIT
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = [
|
||||
"identity", # Get username, verify auth
|
||||
"read", # Access posts and comments
|
||||
"submit", # Submit new posts and comments
|
||||
"edit", # Edit own posts and comments
|
||||
"history", # Access user's post history
|
||||
"privatemessages", # Access inbox and send private messages
|
||||
"flair", # Access and set flair on posts/subreddits
|
||||
]
|
||||
|
||||
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
|
||||
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
|
||||
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
|
||||
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
"""Generate Reddit OAuth 2.0 authorization URL"""
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": " ".join(scopes),
|
||||
"state": state,
|
||||
"duration": "permanent", # Required for refresh tokens
|
||||
}
|
||||
|
||||
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
"""Exchange authorization code for access tokens"""
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
|
||||
# Reddit requires HTTP Basic Auth for token requests
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = await Requests().post(
|
||||
self.TOKEN_URL, headers=headers, data=data, auth=auth
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
error_text = response.text()
|
||||
raise ValueError(
|
||||
f"Reddit token exchange failed: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
|
||||
|
||||
username = await self._get_username(tokens["access_token"])
|
||||
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=None,
|
||||
username=username,
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
||||
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
async def _get_username(self, access_token: str) -> str:
|
||||
"""Get the username from the access token"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
response = await Requests().get(self.USERNAME_URL, headers=headers)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"Failed to get Reddit username: {response.status}")
|
||||
|
||||
data = response.json()
|
||||
return data.get("name", "unknown")
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
"""Refresh access tokens using refresh token"""
|
||||
if not credentials.refresh_token:
|
||||
raise ValueError("No refresh token available")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
}
|
||||
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = await Requests().post(
|
||||
self.TOKEN_URL, headers=headers, data=data, auth=auth
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
error_text = response.text()
|
||||
raise ValueError(
|
||||
f"Reddit token refresh failed: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
|
||||
|
||||
username = await self._get_username(tokens["access_token"])
|
||||
|
||||
# Reddit may or may not return a new refresh token
|
||||
new_refresh_token = tokens.get("refresh_token")
|
||||
if new_refresh_token:
|
||||
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
|
||||
elif credentials.refresh_token:
|
||||
# Keep the existing refresh token
|
||||
refresh_token = credentials.refresh_token
|
||||
else:
|
||||
refresh_token = None
|
||||
|
||||
return OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=credentials.title,
|
||||
username=username,
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=refresh_token,
|
||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
||||
refresh_token_expires_at=None,
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
"""Revoke the access token"""
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
}
|
||||
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = await Requests().post(
|
||||
self.REVOKE_URL, headers=headers, data=data, auth=auth
|
||||
)
|
||||
|
||||
# Reddit returns 204 No Content on successful revocation
|
||||
return response.ok
|
||||
@@ -264,7 +264,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
)
|
||||
|
||||
reddit_user_agent: str = Field(
|
||||
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
|
||||
default="AutoGPT:1.0 (by /u/autogpt)",
|
||||
description="The user agent for the Reddit API",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,227 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate a lightweight stub for prisma/types.py that collapses all exported
|
||||
symbols to Any. This prevents Pyright from spending time/budget on Prisma's
|
||||
query DSL types while keeping runtime behavior unchanged.
|
||||
|
||||
Usage:
|
||||
poetry run gen-prisma-stub
|
||||
|
||||
This script automatically finds the prisma package location and generates
|
||||
the types.pyi stub file in the same directory as types.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
|
||||
def _iter_assigned_names(target: ast.expr) -> Iterable[str]:
|
||||
"""Extract names from assignment targets (handles tuple unpacking)."""
|
||||
if isinstance(target, ast.Name):
|
||||
yield target.id
|
||||
elif isinstance(target, (ast.Tuple, ast.List)):
|
||||
for elt in target.elts:
|
||||
yield from _iter_assigned_names(elt)
|
||||
|
||||
|
||||
def _is_private(name: str) -> bool:
|
||||
"""Check if a name is private (starts with _ but not __)."""
|
||||
return name.startswith("_") and not name.startswith("__")
|
||||
|
||||
|
||||
def _is_safe_type_alias(node: ast.Assign) -> bool:
|
||||
"""Check if an assignment is a safe type alias that shouldn't be stubbed.
|
||||
|
||||
Safe types are:
|
||||
- Literal types (don't cause type budget issues)
|
||||
- Simple type references (SortMode, SortOrder, etc.)
|
||||
- TypeVar definitions
|
||||
"""
|
||||
if not node.value:
|
||||
return False
|
||||
|
||||
# Check if it's a Subscript (like Literal[...], Union[...], TypeVar[...])
|
||||
if isinstance(node.value, ast.Subscript):
|
||||
# Get the base type name
|
||||
if isinstance(node.value.value, ast.Name):
|
||||
base_name = node.value.value.id
|
||||
# Literal types are safe
|
||||
if base_name == "Literal":
|
||||
return True
|
||||
# TypeVar is safe
|
||||
if base_name == "TypeVar":
|
||||
return True
|
||||
elif isinstance(node.value.value, ast.Attribute):
|
||||
# Handle typing_extensions.Literal etc.
|
||||
if node.value.value.attr == "Literal":
|
||||
return True
|
||||
|
||||
# Check if it's a simple Name reference (like SortMode = _types.SortMode)
|
||||
if isinstance(node.value, ast.Attribute):
|
||||
return True
|
||||
|
||||
# Check if it's a Call (like TypeVar(...))
|
||||
if isinstance(node.value, ast.Call):
|
||||
if isinstance(node.value.func, ast.Name):
|
||||
if node.value.func.id == "TypeVar":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def collect_top_level_symbols(
|
||||
tree: ast.Module, source_lines: list[str]
|
||||
) -> tuple[Set[str], Set[str], list[str], Set[str]]:
|
||||
"""Collect all top-level symbols from an AST module.
|
||||
|
||||
Returns:
|
||||
Tuple of (class_names, function_names, safe_variable_sources, unsafe_variable_names)
|
||||
safe_variable_sources contains the actual source code lines for safe variables
|
||||
"""
|
||||
classes: Set[str] = set()
|
||||
functions: Set[str] = set()
|
||||
safe_variable_sources: list[str] = []
|
||||
unsafe_variables: Set[str] = set()
|
||||
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef):
|
||||
if not _is_private(node.name):
|
||||
classes.add(node.name)
|
||||
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if not _is_private(node.name):
|
||||
functions.add(node.name)
|
||||
elif isinstance(node, ast.Assign):
|
||||
is_safe = _is_safe_type_alias(node)
|
||||
names = []
|
||||
for t in node.targets:
|
||||
for n in _iter_assigned_names(t):
|
||||
if not _is_private(n):
|
||||
names.append(n)
|
||||
if names:
|
||||
if is_safe:
|
||||
# Extract the source code for this assignment
|
||||
start_line = node.lineno - 1 # 0-indexed
|
||||
end_line = node.end_lineno if node.end_lineno else node.lineno
|
||||
source = "\n".join(source_lines[start_line:end_line])
|
||||
safe_variable_sources.append(source)
|
||||
else:
|
||||
unsafe_variables.update(names)
|
||||
elif isinstance(node, ast.AnnAssign) and node.target:
|
||||
# Annotated assignments are always stubbed
|
||||
for n in _iter_assigned_names(node.target):
|
||||
if not _is_private(n):
|
||||
unsafe_variables.add(n)
|
||||
|
||||
return classes, functions, safe_variable_sources, unsafe_variables
|
||||
|
||||
|
||||
def find_prisma_types_path() -> Path:
|
||||
"""Find the prisma types.py file in the installed package."""
|
||||
spec = importlib.util.find_spec("prisma")
|
||||
if spec is None or spec.origin is None:
|
||||
raise RuntimeError("Could not find prisma package. Is it installed?")
|
||||
|
||||
prisma_dir = Path(spec.origin).parent
|
||||
types_path = prisma_dir / "types.py"
|
||||
|
||||
if not types_path.exists():
|
||||
raise RuntimeError(f"prisma/types.py not found at {types_path}")
|
||||
|
||||
return types_path
|
||||
|
||||
|
||||
def generate_stub(src_path: Path, stub_path: Path) -> int:
|
||||
"""Generate the .pyi stub file from the source types.py."""
|
||||
code = src_path.read_text(encoding="utf-8", errors="ignore")
|
||||
source_lines = code.splitlines()
|
||||
tree = ast.parse(code, filename=str(src_path))
|
||||
classes, functions, safe_variable_sources, unsafe_variables = (
|
||||
collect_top_level_symbols(tree, source_lines)
|
||||
)
|
||||
|
||||
header = """\
|
||||
# -*- coding: utf-8 -*-
|
||||
# Auto-generated stub file - DO NOT EDIT
|
||||
# Generated by gen_prisma_types_stub.py
|
||||
#
|
||||
# This stub intentionally collapses complex Prisma query DSL types to Any.
|
||||
# Prisma's generated types can explode Pyright's type inference budgets
|
||||
# on large schemas. We collapse them to Any so the rest of the codebase
|
||||
# can remain strongly typed while keeping runtime behavior unchanged.
|
||||
#
|
||||
# Safe types (Literal, TypeVar, simple references) are preserved from the
|
||||
# original types.py to maintain proper type checking where possible.
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from typing_extensions import Literal
|
||||
|
||||
# Re-export commonly used typing constructs that may be imported from this module
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic, Union, Optional, List, Dict
|
||||
|
||||
# Base type alias for stubbed Prisma types - allows any dict structure
|
||||
_PrismaDict = dict[str, Any]
|
||||
|
||||
"""
|
||||
|
||||
lines = [header]
|
||||
|
||||
# Include safe variable definitions (Literal types, TypeVars, etc.)
|
||||
lines.append("# Safe type definitions preserved from original types.py")
|
||||
for source in safe_variable_sources:
|
||||
lines.append(source)
|
||||
lines.append("")
|
||||
|
||||
# Stub all classes and unsafe variables uniformly as dict[str, Any] aliases
|
||||
# This allows:
|
||||
# 1. Use in type annotations: x: SomeType
|
||||
# 2. Constructor calls: SomeType(...)
|
||||
# 3. Dict literal assignments: x: SomeType = {...}
|
||||
lines.append(
|
||||
"# Stubbed types (collapsed to dict[str, Any] to prevent type budget exhaustion)"
|
||||
)
|
||||
all_stubbed = sorted(classes | unsafe_variables)
|
||||
for name in all_stubbed:
|
||||
lines.append(f"{name} = _PrismaDict")
|
||||
|
||||
lines.append("")
|
||||
|
||||
# Stub functions
|
||||
for name in sorted(functions):
|
||||
lines.append(f"def {name}(*args: Any, **kwargs: Any) -> Any: ...")
|
||||
|
||||
lines.append("")
|
||||
|
||||
stub_path.write_text("\n".join(lines), encoding="utf-8")
|
||||
return (
|
||||
len(classes)
|
||||
+ len(functions)
|
||||
+ len(safe_variable_sources)
|
||||
+ len(unsafe_variables)
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point."""
|
||||
try:
|
||||
types_path = find_prisma_types_path()
|
||||
stub_path = types_path.with_suffix(".pyi")
|
||||
|
||||
print(f"Found prisma types.py at: {types_path}")
|
||||
print(f"Generating stub at: {stub_path}")
|
||||
|
||||
num_symbols = generate_stub(types_path, stub_path)
|
||||
print(f"Generated {stub_path.name} with {num_symbols} Any-typed symbols")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -25,9 +25,6 @@ def run(*command: str) -> None:
|
||||
|
||||
|
||||
def lint():
|
||||
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
|
||||
run("gen-prisma-stub")
|
||||
|
||||
lint_step_args: list[list[str]] = [
|
||||
["ruff", "check", *TARGET_DIRS, "--exit-zero"],
|
||||
["ruff", "format", "--diff", "--check", LIBS_DIR],
|
||||
@@ -52,6 +49,4 @@ def format():
|
||||
run("ruff", "format", LIBS_DIR)
|
||||
run("isort", "--profile", "black", BACKEND_DIR)
|
||||
run("black", BACKEND_DIR)
|
||||
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
|
||||
run("gen-prisma-stub")
|
||||
run("pyright", *TARGET_DIRS)
|
||||
|
||||
@@ -117,7 +117,6 @@ lint = "linter:lint"
|
||||
test = "run_tests:test"
|
||||
load-store-agents = "test.load_store_agents:run"
|
||||
export-api-schema = "backend.cli.generate_openapi_json:main"
|
||||
gen-prisma-stub = "gen_prisma_types_stub:main"
|
||||
oauth-tool = "backend.cli.oauth_tool:cli"
|
||||
|
||||
[tool.isort]
|
||||
@@ -135,9 +134,6 @@ ignore_patterns = []
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
# Disable syrupy plugin to avoid conflict with pytest-snapshot
|
||||
# Both provide --snapshot-update argument causing ArgumentError
|
||||
addopts = "-p no:syrupy"
|
||||
filterwarnings = [
|
||||
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
||||
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
"created_at": "2025-09-04T13:37:00",
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
"type": "object"
|
||||
},
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
{
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
"type": "object"
|
||||
},
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
"id": "test-agent-1",
|
||||
"graph_id": "test-agent-1",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
@@ -42,7 +41,6 @@
|
||||
"id": "test-agent-2",
|
||||
"graph_id": "test-agent-2",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{
|
||||
"submissions": [
|
||||
{
|
||||
"listing_id": "test-listing-id",
|
||||
"agent_id": "test-agent-id",
|
||||
"agent_version": 1,
|
||||
"name": "Test Agent",
|
||||
|
||||
@@ -37,7 +37,7 @@ services:
|
||||
context: ../
|
||||
dockerfile: autogpt_platform/backend/Dockerfile
|
||||
target: migrate
|
||||
command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"]
|
||||
command: ["sh", "-c", "poetry run prisma generate && poetry run prisma migrate deploy"]
|
||||
develop:
|
||||
watch:
|
||||
- path: ./
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 2.6 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 16 KiB |
@@ -66,7 +66,6 @@ export const RunInputDialog = ({
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "large",
|
||||
showOptionalToggle: false,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -66,7 +66,7 @@ export const useRunInputDialog = ({
|
||||
if (isCredentialFieldSchema(fieldSchema)) {
|
||||
dynamicUiSchema[fieldName] = {
|
||||
...dynamicUiSchema[fieldName],
|
||||
"ui:field": "custom/credential_field",
|
||||
"ui:field": "credentials",
|
||||
};
|
||||
}
|
||||
});
|
||||
@@ -76,18 +76,12 @@ export const useRunInputDialog = ({
|
||||
}, [credentialsSchema]);
|
||||
|
||||
const handleManualRun = async () => {
|
||||
// Filter out incomplete credentials (those without a valid id)
|
||||
// RJSF auto-populates const values (provider, type) but not id field
|
||||
const validCredentials = Object.fromEntries(
|
||||
Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id),
|
||||
);
|
||||
|
||||
await executeGraph({
|
||||
graphId: flowID ?? "",
|
||||
graphVersion: flowVersion || null,
|
||||
data: {
|
||||
inputs: inputValues,
|
||||
credentials_inputs: validCredentials,
|
||||
credentials_inputs: credentialValues,
|
||||
source: "builder",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -68,10 +68,7 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
<Text
|
||||
variant="large-semibold"
|
||||
className="line-clamp-1 hover:cursor-text"
|
||||
>
|
||||
<Text variant="large-semibold" className="line-clamp-1">
|
||||
{beautifyString(title).replace("Block", "").trim()}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
@@ -151,7 +151,7 @@ export const NodeDataViewer: FC<NodeDataViewerProps> = ({
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end pt-4">
|
||||
{outputItems.length > 1 && (
|
||||
{outputItems.length > 0 && (
|
||||
<OutputActions
|
||||
items={outputItems.map((item) => ({
|
||||
value: item.value,
|
||||
|
||||
@@ -89,18 +89,6 @@ export function extractOptions(
|
||||
|
||||
// get display type and color for schema types [need for type display next to field name]
|
||||
export const getTypeDisplayInfo = (schema: any) => {
|
||||
if (
|
||||
schema?.type === "array" &&
|
||||
"format" in schema &&
|
||||
schema.format === "table"
|
||||
) {
|
||||
return {
|
||||
displayType: "table",
|
||||
colorClass: "!text-indigo-500",
|
||||
hexColor: "#6366f1",
|
||||
};
|
||||
}
|
||||
|
||||
if (schema?.type === "string" && schema?.format) {
|
||||
const formatMap: Record<
|
||||
string,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
export const uiSchema = {
|
||||
credentials: {
|
||||
"ui:field": "custom/credential_field",
|
||||
"ui:field": "credentials",
|
||||
provider: { "ui:widget": "hidden" },
|
||||
type: { "ui:widget": "hidden" },
|
||||
id: { "ui:autofocus": true },
|
||||
|
||||
@@ -68,9 +68,6 @@ type NodeStore = {
|
||||
clearAllNodeErrors: () => void; // Add this
|
||||
|
||||
syncHardcodedValuesWithHandleIds: (nodeId: string) => void;
|
||||
|
||||
// Credentials optional helpers
|
||||
setCredentialsOptional: (nodeId: string, optional: boolean) => void;
|
||||
};
|
||||
|
||||
export const useNodeStore = create<NodeStore>((set, get) => ({
|
||||
@@ -229,9 +226,6 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
|
||||
...(node.data.metadata?.customized_name !== undefined && {
|
||||
customized_name: node.data.metadata.customized_name,
|
||||
}),
|
||||
...(node.data.metadata?.credentials_optional !== undefined && {
|
||||
credentials_optional: node.data.metadata.credentials_optional,
|
||||
}),
|
||||
},
|
||||
};
|
||||
},
|
||||
@@ -348,30 +342,4 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
|
||||
}));
|
||||
}
|
||||
},
|
||||
|
||||
setCredentialsOptional: (nodeId: string, optional: boolean) => {
|
||||
set((state) => ({
|
||||
nodes: state.nodes.map((n) =>
|
||||
n.id === nodeId
|
||||
? {
|
||||
...n,
|
||||
data: {
|
||||
...n.data,
|
||||
metadata: {
|
||||
...n.data.metadata,
|
||||
credentials_optional: optional,
|
||||
},
|
||||
},
|
||||
}
|
||||
: n,
|
||||
),
|
||||
}));
|
||||
|
||||
const newState = {
|
||||
nodes: get().nodes,
|
||||
edges: useEdgeStore.getState().edges,
|
||||
};
|
||||
|
||||
useHistoryStore.getState().pushState(newState);
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -34,9 +34,7 @@ type Props = {
|
||||
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
|
||||
onLoaded?: (loaded: boolean) => void;
|
||||
readOnly?: boolean;
|
||||
isOptional?: boolean;
|
||||
showTitle?: boolean;
|
||||
variant?: "default" | "node";
|
||||
};
|
||||
|
||||
export function CredentialsInput({
|
||||
@@ -47,9 +45,7 @@ export function CredentialsInput({
|
||||
siblingInputs,
|
||||
onLoaded,
|
||||
readOnly = false,
|
||||
isOptional = false,
|
||||
showTitle = true,
|
||||
variant = "default",
|
||||
}: Props) {
|
||||
const hookData = useCredentialsInput({
|
||||
schema,
|
||||
@@ -58,7 +54,6 @@ export function CredentialsInput({
|
||||
siblingInputs,
|
||||
onLoaded,
|
||||
readOnly,
|
||||
isOptional,
|
||||
});
|
||||
|
||||
if (!isLoaded(hookData)) {
|
||||
@@ -99,14 +94,7 @@ export function CredentialsInput({
|
||||
<div className={cn("mb-6", className)}>
|
||||
{showTitle && (
|
||||
<div className="mb-2 flex items-center gap-2">
|
||||
<Text variant="large-medium">
|
||||
{displayName} credentials
|
||||
{isOptional && (
|
||||
<span className="ml-1 text-sm font-normal text-gray-500">
|
||||
(optional)
|
||||
</span>
|
||||
)}
|
||||
</Text>
|
||||
<Text variant="large-medium">{displayName} credentials</Text>
|
||||
{schema.description && (
|
||||
<InformationTooltip description={schema.description} />
|
||||
)}
|
||||
@@ -115,17 +103,14 @@ export function CredentialsInput({
|
||||
|
||||
{hasCredentialsToShow ? (
|
||||
<>
|
||||
{(credentialsToShow.length > 1 || isOptional) && !readOnly ? (
|
||||
{credentialsToShow.length > 1 && !readOnly ? (
|
||||
<CredentialsSelect
|
||||
credentials={credentialsToShow}
|
||||
provider={provider}
|
||||
displayName={displayName}
|
||||
selectedCredentials={selectedCredential}
|
||||
onSelectCredential={handleCredentialSelect}
|
||||
onClearCredential={() => onSelectCredential(undefined)}
|
||||
readOnly={readOnly}
|
||||
allowNone={isOptional}
|
||||
variant={variant}
|
||||
/>
|
||||
) : (
|
||||
<div className="mb-4 space-y-2">
|
||||
|
||||
@@ -30,8 +30,6 @@ type CredentialRowProps = {
|
||||
readOnly?: boolean;
|
||||
showCaret?: boolean;
|
||||
asSelectTrigger?: boolean;
|
||||
/** When "node", applies compact styling for node context */
|
||||
variant?: "default" | "node";
|
||||
};
|
||||
|
||||
export function CredentialRow({
|
||||
@@ -43,22 +41,14 @@ export function CredentialRow({
|
||||
readOnly = false,
|
||||
showCaret = false,
|
||||
asSelectTrigger = false,
|
||||
variant = "default",
|
||||
}: CredentialRowProps) {
|
||||
const ProviderIcon = providerIcons[provider] || fallbackIcon;
|
||||
const isNodeVariant = variant === "node";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center gap-3 rounded-medium border border-zinc-200 bg-white p-3 transition-colors",
|
||||
asSelectTrigger && isNodeVariant
|
||||
? "min-w-0 flex-1 overflow-hidden border-0 bg-transparent"
|
||||
: asSelectTrigger
|
||||
? "border-0 bg-transparent"
|
||||
: readOnly
|
||||
? "w-fit"
|
||||
: "",
|
||||
asSelectTrigger ? "border-0 bg-transparent" : readOnly ? "w-fit" : "",
|
||||
)}
|
||||
onClick={readOnly || showCaret || asSelectTrigger ? undefined : onSelect}
|
||||
style={
|
||||
@@ -71,31 +61,19 @@ export function CredentialRow({
|
||||
<ProviderIcon className="h-3 w-3 text-white" />
|
||||
</div>
|
||||
<IconKey className="h-5 w-5 shrink-0 text-zinc-800" />
|
||||
<div
|
||||
className={cn(
|
||||
"flex min-w-0 flex-1 flex-nowrap items-center gap-4",
|
||||
isNodeVariant && "overflow-hidden",
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 flex-1 flex-nowrap items-center gap-4">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"tracking-tight",
|
||||
isNodeVariant
|
||||
? "truncate"
|
||||
: "line-clamp-1 flex-[0_0_50%] text-ellipsis",
|
||||
)}
|
||||
className="line-clamp-1 flex-[0_0_50%] text-ellipsis tracking-tight"
|
||||
>
|
||||
{getCredentialDisplayName(credential, displayName)}
|
||||
</Text>
|
||||
{!(asSelectTrigger && isNodeVariant) && (
|
||||
<Text
|
||||
variant="large"
|
||||
className="relative top-1 hidden overflow-hidden whitespace-nowrap font-mono tracking-tight md:block"
|
||||
>
|
||||
{"*".repeat(MASKED_KEY_LENGTH)}
|
||||
</Text>
|
||||
)}
|
||||
<Text
|
||||
variant="large"
|
||||
className="lex-[0_0_40%] relative top-1 hidden overflow-hidden whitespace-nowrap font-mono tracking-tight md:block"
|
||||
>
|
||||
{"*".repeat(MASKED_KEY_LENGTH)}
|
||||
</Text>
|
||||
</div>
|
||||
{showCaret && !asSelectTrigger && (
|
||||
<CaretDown className="h-4 w-4 shrink-0 text-gray-400" />
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
} from "@/components/__legacy__/ui/select";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useEffect } from "react";
|
||||
import { getCredentialDisplayName } from "../../helpers";
|
||||
import { CredentialRow } from "../CredentialRow/CredentialRow";
|
||||
@@ -24,11 +23,7 @@ interface Props {
|
||||
displayName: string;
|
||||
selectedCredentials?: CredentialsMetaInput;
|
||||
onSelectCredential: (credentialId: string) => void;
|
||||
onClearCredential?: () => void;
|
||||
readOnly?: boolean;
|
||||
allowNone?: boolean;
|
||||
/** When "node", applies compact styling for node context */
|
||||
variant?: "default" | "node";
|
||||
}
|
||||
|
||||
export function CredentialsSelect({
|
||||
@@ -37,38 +32,22 @@ export function CredentialsSelect({
|
||||
displayName,
|
||||
selectedCredentials,
|
||||
onSelectCredential,
|
||||
onClearCredential,
|
||||
readOnly = false,
|
||||
allowNone = true,
|
||||
variant = "default",
|
||||
}: Props) {
|
||||
// Auto-select first credential if none is selected (only if allowNone is false)
|
||||
// Auto-select first credential if none is selected
|
||||
useEffect(() => {
|
||||
if (!allowNone && !selectedCredentials && credentials.length > 0) {
|
||||
if (!selectedCredentials && credentials.length > 0) {
|
||||
onSelectCredential(credentials[0].id);
|
||||
}
|
||||
}, [allowNone, selectedCredentials, credentials, onSelectCredential]);
|
||||
|
||||
const handleValueChange = (value: string) => {
|
||||
if (value === "__none__") {
|
||||
onClearCredential?.();
|
||||
} else {
|
||||
onSelectCredential(value);
|
||||
}
|
||||
};
|
||||
}, [selectedCredentials, credentials, onSelectCredential]);
|
||||
|
||||
return (
|
||||
<div className="mb-4 w-full">
|
||||
<Select
|
||||
value={selectedCredentials?.id || (allowNone ? "__none__" : "")}
|
||||
onValueChange={handleValueChange}
|
||||
value={selectedCredentials?.id || ""}
|
||||
onValueChange={(value) => onSelectCredential(value)}
|
||||
>
|
||||
<SelectTrigger
|
||||
className={cn(
|
||||
"h-auto min-h-12 w-full rounded-medium border-zinc-200 p-0 pr-4 shadow-none",
|
||||
variant === "node" && "overflow-hidden",
|
||||
)}
|
||||
>
|
||||
<SelectTrigger className="h-auto min-h-12 w-full rounded-medium border-zinc-200 p-0 pr-4 shadow-none">
|
||||
{selectedCredentials ? (
|
||||
<SelectValue key={selectedCredentials.id} asChild>
|
||||
<CredentialRow
|
||||
@@ -84,7 +63,6 @@ export function CredentialsSelect({
|
||||
onDelete={() => {}}
|
||||
readOnly={readOnly}
|
||||
asSelectTrigger={true}
|
||||
variant={variant}
|
||||
/>
|
||||
</SelectValue>
|
||||
) : (
|
||||
@@ -92,15 +70,6 @@ export function CredentialsSelect({
|
||||
)}
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{allowNone && (
|
||||
<SelectItem key="__none__" value="__none__">
|
||||
<div className="flex items-center gap-2">
|
||||
<Text variant="body" className="tracking-tight text-gray-500">
|
||||
None (skip this credential)
|
||||
</Text>
|
||||
</div>
|
||||
</SelectItem>
|
||||
)}
|
||||
{credentials.map((credential) => (
|
||||
<SelectItem key={credential.id} value={credential.id}>
|
||||
<div className="flex items-center gap-2">
|
||||
|
||||
@@ -22,7 +22,6 @@ type Params = {
|
||||
siblingInputs?: Record<string, any>;
|
||||
onLoaded?: (loaded: boolean) => void;
|
||||
readOnly?: boolean;
|
||||
isOptional?: boolean;
|
||||
};
|
||||
|
||||
export function useCredentialsInput({
|
||||
@@ -32,7 +31,6 @@ export function useCredentialsInput({
|
||||
siblingInputs,
|
||||
onLoaded,
|
||||
readOnly = false,
|
||||
isOptional = false,
|
||||
}: Params) {
|
||||
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
|
||||
useState(false);
|
||||
@@ -101,20 +99,13 @@ export function useCredentialsInput({
|
||||
: null;
|
||||
}, [credentials]);
|
||||
|
||||
// Auto-select the one available credential (only if not optional)
|
||||
// Auto-select the one available credential
|
||||
useEffect(() => {
|
||||
if (readOnly) return;
|
||||
if (isOptional) return; // Don't auto-select when credential is optional
|
||||
if (singleCredential && !selectedCredential) {
|
||||
onSelectCredential(singleCredential);
|
||||
}
|
||||
}, [
|
||||
singleCredential,
|
||||
selectedCredential,
|
||||
onSelectCredential,
|
||||
readOnly,
|
||||
isOptional,
|
||||
]);
|
||||
}, [singleCredential, selectedCredential, onSelectCredential, readOnly]);
|
||||
|
||||
if (
|
||||
!credentials ||
|
||||
|
||||
@@ -8,7 +8,6 @@ import { WebhookTriggerBanner } from "../WebhookTriggerBanner/WebhookTriggerBann
|
||||
|
||||
export function ModalRunSection() {
|
||||
const {
|
||||
agent,
|
||||
defaultRunType,
|
||||
presetName,
|
||||
setPresetName,
|
||||
@@ -25,11 +24,6 @@ export function ModalRunSection() {
|
||||
const inputFields = Object.entries(agentInputFields || {});
|
||||
const credentialFields = Object.entries(agentCredentialsInputFields || {});
|
||||
|
||||
// Get the list of required credentials from the schema
|
||||
const requiredCredentials = new Set(
|
||||
(agent.credentials_input_schema?.required as string[]) || [],
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
{defaultRunType === "automatic-trigger" ||
|
||||
@@ -105,12 +99,14 @@ export function ModalRunSection() {
|
||||
schema={
|
||||
{ ...inputSubSchema, discriminator: undefined } as any
|
||||
}
|
||||
selectedCredentials={inputCredentials?.[key]}
|
||||
selectedCredentials={
|
||||
(inputCredentials && inputCredentials[key]) ??
|
||||
inputSubSchema.default
|
||||
}
|
||||
onSelectCredentials={(value) =>
|
||||
setInputCredentialsValue(key, value)
|
||||
}
|
||||
siblingInputs={inputValues}
|
||||
isOptional={!requiredCredentials.has(key)}
|
||||
/>
|
||||
),
|
||||
)}
|
||||
|
||||
@@ -163,21 +163,15 @@ export function useAgentRunModal(
|
||||
}, [agentInputSchema.required, inputValues]);
|
||||
|
||||
const [allCredentialsAreSet, missingCredentials] = useMemo(() => {
|
||||
// Only check required credentials from schema, not all properties
|
||||
// Credentials marked as optional in node metadata won't be in the required array
|
||||
const requiredCredentials = new Set(
|
||||
(agent.credentials_input_schema?.required as string[]) || [],
|
||||
const availableCredentials = new Set(Object.keys(inputCredentials));
|
||||
const allCredentials = new Set(
|
||||
Object.keys(agentCredentialsInputFields || {}) ?? [],
|
||||
);
|
||||
const missing = [...allCredentials].filter(
|
||||
(key) => !availableCredentials.has(key),
|
||||
);
|
||||
|
||||
// Check if required credentials have valid id (not just key existence)
|
||||
// A credential is valid only if it has an id field set
|
||||
const missing = [...requiredCredentials].filter((key) => {
|
||||
const cred = inputCredentials[key];
|
||||
return !cred || !cred.id;
|
||||
});
|
||||
|
||||
return [missing.length === 0, missing];
|
||||
}, [agent.credentials_input_schema, inputCredentials]);
|
||||
}, [agentCredentialsInputFields, inputCredentials]);
|
||||
|
||||
const credentialsRequired = useMemo(
|
||||
() => Object.keys(agentCredentialsInputFields || {}).length > 0,
|
||||
@@ -245,18 +239,12 @@ export function useAgentRunModal(
|
||||
});
|
||||
} else {
|
||||
// Manual execution
|
||||
// Filter out incomplete credentials (optional ones not selected)
|
||||
// Only send credentials that have a valid id field
|
||||
const validCredentials = Object.fromEntries(
|
||||
Object.entries(inputCredentials).filter(([_, cred]) => cred && cred.id),
|
||||
);
|
||||
|
||||
executeGraphMutation.mutate({
|
||||
graphId: agent.graph_id,
|
||||
graphVersion: agent.graph_version,
|
||||
data: {
|
||||
inputs: inputValues,
|
||||
credentials_inputs: validCredentials,
|
||||
credentials_inputs: inputCredentials,
|
||||
source: "library",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -83,9 +83,7 @@ function renderCode(
|
||||
</div>
|
||||
)}
|
||||
<pre className="overflow-x-auto rounded-md bg-muted p-3">
|
||||
<code className="whitespace-pre-wrap break-words font-mono text-sm">
|
||||
{codeValue}
|
||||
</code>
|
||||
<code className="font-mono text-sm">{codeValue}</code>
|
||||
</pre>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -40,17 +40,15 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
|
||||
},
|
||||
);
|
||||
|
||||
// Get user's submissions - only fetch if user is the creator
|
||||
const { data: submissionsData, isLoading: isSubmissionsLoading } =
|
||||
useGetV2ListMySubmissions(
|
||||
{ page: 1, page_size: 50 },
|
||||
{
|
||||
query: {
|
||||
// Only fetch if user is the creator
|
||||
enabled: !!(user?.id && agent?.owner_user_id === user.id),
|
||||
},
|
||||
// Get user's submissions to check for pending submissions
|
||||
const { data: submissionsData } = useGetV2ListMySubmissions(
|
||||
{ page: 1, page_size: 50 }, // Get enough to cover recent submissions
|
||||
{
|
||||
query: {
|
||||
enabled: !!user?.id, // Only fetch if user is authenticated
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
const updateToLatestMutation = usePatchV2UpdateLibraryAgent({
|
||||
mutation: {
|
||||
@@ -80,45 +78,16 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
|
||||
// Check if marketplace has a newer version than user's current version
|
||||
const marketplaceUpdateInfo = React.useMemo(() => {
|
||||
const storeAgent = okData(storeAgentData) as any;
|
||||
|
||||
if (!agent || isSubmissionsLoading) {
|
||||
if (!agent || !storeAgent) {
|
||||
return {
|
||||
hasUpdate: false,
|
||||
latestVersion: undefined,
|
||||
isUserCreator: false,
|
||||
hasPublishUpdate: false,
|
||||
};
|
||||
}
|
||||
|
||||
const isUserCreator = agent?.owner_user_id === user?.id;
|
||||
|
||||
const submissionsResponse = okData(submissionsData) as any;
|
||||
const agentSubmissions =
|
||||
submissionsResponse?.submissions?.filter(
|
||||
(submission: StoreSubmission) => submission.agent_id === agent.graph_id,
|
||||
) || [];
|
||||
|
||||
const highestSubmittedVersion =
|
||||
agentSubmissions.length > 0
|
||||
? Math.max(
|
||||
...agentSubmissions.map(
|
||||
(submission: StoreSubmission) => submission.agent_version,
|
||||
),
|
||||
)
|
||||
: 0;
|
||||
|
||||
const hasUnpublishedChanges =
|
||||
isUserCreator && agent.graph_version > highestSubmittedVersion;
|
||||
|
||||
if (!storeAgent) {
|
||||
return {
|
||||
hasUpdate: false,
|
||||
latestVersion: undefined,
|
||||
isUserCreator,
|
||||
hasPublishUpdate: agentSubmissions.length > 0 && hasUnpublishedChanges,
|
||||
};
|
||||
}
|
||||
|
||||
// Get the latest version from the marketplace
|
||||
// agentGraphVersions array contains graph version numbers as strings, get the highest one
|
||||
const latestMarketplaceVersion =
|
||||
storeAgent.agentGraphVersions?.length > 0
|
||||
? Math.max(
|
||||
@@ -128,11 +97,32 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
|
||||
)
|
||||
: undefined;
|
||||
|
||||
// Determine if the user is the creator of this agent
|
||||
// Compare current user ID with the marketplace listing creator ID
|
||||
const isUserCreator =
|
||||
user?.id && agent.marketplace_listing?.creator.id === user.id;
|
||||
|
||||
// Check if there's a pending submission for this specific agent version
|
||||
const submissionsResponse = okData(submissionsData) as any;
|
||||
const hasPendingSubmissionForCurrentVersion =
|
||||
isUserCreator &&
|
||||
submissionsResponse?.submissions?.some(
|
||||
(submission: StoreSubmission) =>
|
||||
submission.agent_id === agent.graph_id &&
|
||||
submission.agent_version === agent.graph_version &&
|
||||
submission.status === "PENDING",
|
||||
);
|
||||
|
||||
// If user is creator and their version is newer than marketplace, show publish update banner
|
||||
// BUT only if there's no pending submission for this version
|
||||
const hasPublishUpdate =
|
||||
isUserCreator &&
|
||||
agent.graph_version >
|
||||
Math.max(latestMarketplaceVersion || 0, highestSubmittedVersion);
|
||||
!hasPendingSubmissionForCurrentVersion &&
|
||||
latestMarketplaceVersion !== undefined &&
|
||||
agent.graph_version > latestMarketplaceVersion;
|
||||
|
||||
// If marketplace version is newer than user's version, show update banner
|
||||
// This applies to both creators and non-creators
|
||||
const hasMarketplaceUpdate =
|
||||
latestMarketplaceVersion !== undefined &&
|
||||
latestMarketplaceVersion > agent.graph_version;
|
||||
@@ -143,7 +133,7 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
|
||||
isUserCreator,
|
||||
hasPublishUpdate,
|
||||
};
|
||||
}, [agent, storeAgentData, user, submissionsData, isSubmissionsLoading]);
|
||||
}, [agent, storeAgentData, user, submissionsData]);
|
||||
|
||||
const handlePublishUpdate = () => {
|
||||
setModalOpen(true);
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
"use client";
|
||||
|
||||
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { HeartIcon } from "@phosphor-icons/react";
|
||||
import { useFavoriteAgents } from "../../hooks/useFavoriteAgents";
|
||||
import { LibraryAgentCard } from "../LibraryAgentCard/LibraryAgentCard";
|
||||
|
||||
interface Props {
|
||||
searchTerm: string;
|
||||
}
|
||||
|
||||
export function FavoritesSection({ searchTerm }: Props) {
|
||||
export function FavoritesSection() {
|
||||
const isAgentFavoritingEnabled = useGetFlag(Flag.AGENT_FAVORITING);
|
||||
const {
|
||||
allAgents: favoriteAgents,
|
||||
agentLoading: isLoading,
|
||||
@@ -19,50 +17,60 @@ export function FavoritesSection({ searchTerm }: Props) {
|
||||
hasNextPage,
|
||||
fetchNextPage,
|
||||
isFetchingNextPage,
|
||||
} = useFavoriteAgents({ searchTerm });
|
||||
} = useFavoriteAgents();
|
||||
|
||||
if (isLoading || favoriteAgents.length === 0) {
|
||||
// Only show this section if the feature flag is enabled
|
||||
if (!isAgentFavoritingEnabled) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Don't show the section if there are no favorites
|
||||
if (!isLoading && favoriteAgents.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="!mb-8">
|
||||
<div className="mb-3 flex items-center gap-2 p-2">
|
||||
<HeartIcon className="h-5 w-5" weight="fill" />
|
||||
<div className="flex items-baseline gap-2">
|
||||
<Text variant="h4">Favorites</Text>
|
||||
{!isLoading && (
|
||||
<Text
|
||||
variant="body"
|
||||
data-testid="agents-count"
|
||||
className="relative bottom-px text-zinc-500"
|
||||
>
|
||||
{agentCount}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
<div className="mb-8">
|
||||
<div className="flex items-center gap-[10px] p-2 pb-[10px]">
|
||||
<HeartIcon className="h-5 w-5 fill-red-500 text-red-500" />
|
||||
<span className="font-poppin text-[18px] font-semibold leading-[28px] text-neutral-800">
|
||||
Favorites
|
||||
</span>
|
||||
{!isLoading && (
|
||||
<span className="font-sans text-[14px] font-normal leading-6">
|
||||
{agentCount} {agentCount === 1 ? "agent" : "agents"}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="relative">
|
||||
<InfiniteScroll
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
fetchNextPage={fetchNextPage}
|
||||
hasNextPage={hasNextPage}
|
||||
loader={
|
||||
<div className="flex h-8 w-full items-center justify-center">
|
||||
<div className="h-6 w-6 animate-spin rounded-full border-b-2 border-t-2 border-neutral-800" />
|
||||
</div>
|
||||
}
|
||||
>
|
||||
{isLoading ? (
|
||||
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
|
||||
{favoriteAgents.map((agent: LibraryAgent) => (
|
||||
<LibraryAgentCard key={agent.id} agent={agent} />
|
||||
{[...Array(4)].map((_, i) => (
|
||||
<Skeleton key={i} className="h-48 w-full rounded-lg" />
|
||||
))}
|
||||
</div>
|
||||
</InfiniteScroll>
|
||||
) : (
|
||||
<InfiniteScroll
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
fetchNextPage={fetchNextPage}
|
||||
hasNextPage={hasNextPage}
|
||||
loader={
|
||||
<div className="flex h-8 w-full items-center justify-center">
|
||||
<div className="h-6 w-6 animate-spin rounded-full border-b-2 border-t-2 border-neutral-800" />
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
|
||||
{favoriteAgents.map((agent: LibraryAgent) => (
|
||||
<LibraryAgentCard key={agent.id} agent={agent} />
|
||||
))}
|
||||
</div>
|
||||
</InfiniteScroll>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{favoriteAgents.length > 0 && <div className="!mt-10 border-t" />}
|
||||
{favoriteAgents.length > 0 && <div className="mt-6 border-t pt-6" />}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ export function LibraryAgentCard({ agent }: Props) {
|
||||
|
||||
const {
|
||||
isFromMarketplace,
|
||||
isAgentFavoritingEnabled,
|
||||
isFavorite,
|
||||
profile,
|
||||
creator_image_url,
|
||||
@@ -36,8 +37,9 @@ export function LibraryAgentCard({ agent }: Props) {
|
||||
data-agent-id={id}
|
||||
className="group relative inline-flex h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border border-zinc-100 bg-white transition-all duration-300 hover:shadow-md"
|
||||
>
|
||||
<NextLink href={`/library/agents/${id}`} className="flex-shrink-0">
|
||||
<div className="relative flex items-center gap-2 px-4 pt-3">
|
||||
<AgentCardMenu agent={agent} />
|
||||
<NextLink href={`/library/agents/${id}`} className="w-full flex-shrink-0">
|
||||
<div className="flex items-center gap-2 px-4 pt-3">
|
||||
<Avatar className="h-4 w-4 rounded-full">
|
||||
<AvatarImage
|
||||
src={
|
||||
@@ -56,13 +58,13 @@ export function LibraryAgentCard({ agent }: Props) {
|
||||
{isFromMarketplace ? "FROM MARKETPLACE" : "Built by you"}
|
||||
</Text>
|
||||
</div>
|
||||
{isAgentFavoritingEnabled && (
|
||||
<FavoriteButton
|
||||
isFavorite={isFavorite}
|
||||
onClick={handleToggleFavorite}
|
||||
/>
|
||||
)}
|
||||
</NextLink>
|
||||
<FavoriteButton
|
||||
isFavorite={isFavorite}
|
||||
onClick={handleToggleFavorite}
|
||||
className="absolute right-10 top-0"
|
||||
/>
|
||||
<AgentCardMenu agent={agent} />
|
||||
|
||||
<div className="flex w-full flex-1 flex-col px-4 pb-2">
|
||||
<Link
|
||||
|
||||
@@ -2,27 +2,21 @@
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { HeartIcon } from "@phosphor-icons/react";
|
||||
import type { MouseEvent } from "react";
|
||||
|
||||
interface FavoriteButtonProps {
|
||||
isFavorite: boolean;
|
||||
onClick: (e: MouseEvent<HTMLButtonElement>) => void;
|
||||
className?: string;
|
||||
onClick: (e: React.MouseEvent) => void;
|
||||
}
|
||||
|
||||
export function FavoriteButton({
|
||||
isFavorite,
|
||||
onClick,
|
||||
className,
|
||||
}: FavoriteButtonProps) {
|
||||
export function FavoriteButton({ isFavorite, onClick }: FavoriteButtonProps) {
|
||||
return (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className={cn(
|
||||
"rounded-full p-2 transition-all duration-200",
|
||||
"hover:scale-110",
|
||||
"rounded-full bg-white/90 p-2 backdrop-blur-sm transition-all duration-200",
|
||||
"hover:scale-110 hover:bg-white",
|
||||
"focus:outline-none focus:ring-2 focus:ring-red-500 focus:ring-offset-2",
|
||||
!isFavorite && "opacity-0 group-hover:opacity-100",
|
||||
className,
|
||||
)}
|
||||
aria-label={isFavorite ? "Remove from favorites" : "Add to favorites"}
|
||||
>
|
||||
|
||||
@@ -8,6 +8,7 @@ import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/
|
||||
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { updateFavoriteInQueries } from "./helpers";
|
||||
|
||||
interface Props {
|
||||
@@ -19,6 +20,7 @@ export function useLibraryAgentCard({ agent }: Props) {
|
||||
agent;
|
||||
|
||||
const isFromMarketplace = Boolean(marketplace_listing);
|
||||
const isAgentFavoritingEnabled = useGetFlag(Flag.AGENT_FAVORITING);
|
||||
const [isFavorite, setIsFavorite] = useState(is_favorite);
|
||||
const { toast } = useToast();
|
||||
const queryClient = getQueryClient();
|
||||
@@ -47,6 +49,8 @@ export function useLibraryAgentCard({ agent }: Props) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
|
||||
if (!isAgentFavoritingEnabled) return;
|
||||
|
||||
const newIsFavorite = !isFavorite;
|
||||
|
||||
setIsFavorite(newIsFavorite);
|
||||
@@ -76,6 +80,7 @@ export function useLibraryAgentCard({ agent }: Props) {
|
||||
|
||||
return {
|
||||
isFromMarketplace,
|
||||
isAgentFavoritingEnabled,
|
||||
isFavorite,
|
||||
profile,
|
||||
creator_image_url,
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
getPaginatedTotalCount,
|
||||
getPaginationNextPageNumber,
|
||||
unpaginate,
|
||||
} from "@/app/api/helpers";
|
||||
import { useGetV2ListFavoriteLibraryAgentsInfinite } from "@/app/api/__generated__/endpoints/library/library";
|
||||
import { getPaginationNextPageNumber, unpaginate } from "@/app/api/helpers";
|
||||
import { useMemo } from "react";
|
||||
import { filterAgents } from "../components/LibraryAgentList/helpers";
|
||||
|
||||
interface Props {
|
||||
searchTerm: string;
|
||||
}
|
||||
|
||||
export function useFavoriteAgents({ searchTerm }: Props) {
|
||||
export function useFavoriteAgents() {
|
||||
const {
|
||||
data: agentsQueryData,
|
||||
fetchNextPage,
|
||||
@@ -29,16 +27,10 @@ export function useFavoriteAgents({ searchTerm }: Props) {
|
||||
const allAgents = agentsQueryData
|
||||
? unpaginate(agentsQueryData, "agents")
|
||||
: [];
|
||||
|
||||
const filteredAgents = useMemo(
|
||||
() => filterAgents(allAgents, searchTerm),
|
||||
[allAgents, searchTerm],
|
||||
);
|
||||
|
||||
const agentCount = filteredAgents.length;
|
||||
const agentCount = getPaginatedTotalCount(agentsQueryData);
|
||||
|
||||
return {
|
||||
allAgents: filteredAgents,
|
||||
allAgents,
|
||||
agentLoading,
|
||||
hasNextPage,
|
||||
agentCount,
|
||||
|
||||
@@ -17,7 +17,7 @@ export default function LibraryPage() {
|
||||
return (
|
||||
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
|
||||
<LibraryActionHeader setSearchTerm={setSearchTerm} />
|
||||
<FavoritesSection searchTerm={searchTerm} />
|
||||
<FavoritesSection />
|
||||
<LibraryAgentList
|
||||
searchTerm={searchTerm}
|
||||
librarySort={librarySort}
|
||||
|
||||
@@ -12,7 +12,6 @@ import type { GetV2GetSpecificAgentParams } from "@/app/api/__generated__/models
|
||||
import { useAgentInfo } from "./useAgentInfo";
|
||||
import { useGetV2GetSpecificAgent } from "@/app/api/__generated__/endpoints/store/store";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { formatTimeAgo } from "@/lib/utils/time";
|
||||
import * as React from "react";
|
||||
|
||||
interface AgentInfoProps {
|
||||
@@ -259,29 +258,23 @@ export const AgentInfo = ({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Version history */}
|
||||
{/* Changelog */}
|
||||
<div className="flex w-full flex-col gap-1.5 sm:gap-2">
|
||||
<div className="decoration-skip-ink-none text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200">
|
||||
Version history
|
||||
<div className="decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
|
||||
Changelog
|
||||
</div>
|
||||
<div className="decoration-skip-ink-none text-sm font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
|
||||
Last updated {formatTimeAgo(lastUpdated)}
|
||||
</div>
|
||||
<div className="decoration-skip-ink-none text-xs text-neutral-600 dark:text-neutral-400 sm:text-sm">
|
||||
Version {version}.0
|
||||
<div className="decoration-skip-ink-none text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
|
||||
Last updated {lastUpdated}
|
||||
</div>
|
||||
|
||||
{/* Version List */}
|
||||
{agentVersions.length > 0 ? (
|
||||
<div className="mt-3">
|
||||
<div className="decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-900 dark:text-neutral-200 sm:mb-2">
|
||||
Changelog
|
||||
</div>
|
||||
<div className="mt-4">
|
||||
{agentVersions.map(renderVersionItem)}
|
||||
{hasMoreVersions && (
|
||||
<button
|
||||
onClick={() => setVisibleVersionCount((prev) => prev + 3)}
|
||||
className="mt-2 flex items-center gap-1 text-sm font-medium text-neutral-700 hover:text-neutral-700 dark:text-neutral-100 dark:hover:text-neutral-300"
|
||||
className="mt-2 flex items-center gap-1 text-sm font-medium text-neutral-900 hover:text-neutral-700 dark:text-neutral-100 dark:hover:text-neutral-300"
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
@@ -304,7 +297,7 @@ export const AgentInfo = ({
|
||||
</div>
|
||||
) : (
|
||||
<div className="text-xs text-neutral-600 dark:text-neutral-400 sm:text-sm">
|
||||
Version {version}.0
|
||||
Version {version}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -18,7 +18,6 @@ export interface AgentTableCardProps {
|
||||
runs: number;
|
||||
rating: number;
|
||||
id: number;
|
||||
listing_id?: string;
|
||||
onViewSubmission: (submission: StoreSubmission) => void;
|
||||
}
|
||||
|
||||
@@ -33,12 +32,10 @@ export const AgentTableCard = ({
|
||||
status,
|
||||
runs,
|
||||
rating,
|
||||
listing_id,
|
||||
onViewSubmission,
|
||||
}: AgentTableCardProps) => {
|
||||
const onView = () => {
|
||||
onViewSubmission({
|
||||
listing_id: listing_id || "",
|
||||
agent_id,
|
||||
agent_version,
|
||||
slug: "",
|
||||
@@ -65,14 +62,9 @@ export const AgentTableCard = ({
|
||||
/>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="text-[15px] font-medium text-neutral-800 dark:text-neutral-200">
|
||||
{agentName}
|
||||
</h3>
|
||||
<span className="text-[13px] text-neutral-500 dark:text-neutral-400">
|
||||
v{agent_version}
|
||||
</span>
|
||||
</div>
|
||||
<h3 className="text-[15px] font-medium text-neutral-800 dark:text-neutral-200">
|
||||
{agentName}
|
||||
</h3>
|
||||
<p className="line-clamp-2 text-sm text-neutral-600 dark:text-neutral-400">
|
||||
{description}
|
||||
</p>
|
||||
|
||||
@@ -9,11 +9,11 @@ import { useAgentTableRow } from "./useAgentTableRow";
|
||||
import { StoreSubmission } from "@/app/api/__generated__/models/storeSubmission";
|
||||
import {
|
||||
DotsThreeVerticalIcon,
|
||||
EyeIcon,
|
||||
Eye,
|
||||
ImageBroken,
|
||||
StarIcon,
|
||||
TrashIcon,
|
||||
PencilIcon,
|
||||
Star,
|
||||
Trash,
|
||||
PencilSimple,
|
||||
} from "@phosphor-icons/react/dist/ssr";
|
||||
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
import { StoreSubmissionEditRequest } from "@/app/api/__generated__/models/storeSubmissionEditRequest";
|
||||
@@ -34,7 +34,6 @@ export interface AgentTableRowProps {
|
||||
categories?: string[];
|
||||
store_listing_version_id?: string;
|
||||
changes_summary?: string;
|
||||
listing_id?: string;
|
||||
onViewSubmission: (submission: StoreSubmission) => void;
|
||||
onDeleteSubmission: (submission_id: string) => void;
|
||||
onEditSubmission: (
|
||||
@@ -61,7 +60,6 @@ export const AgentTableRow = ({
|
||||
categories,
|
||||
store_listing_version_id,
|
||||
changes_summary,
|
||||
listing_id,
|
||||
onViewSubmission,
|
||||
onDeleteSubmission,
|
||||
onEditSubmission,
|
||||
@@ -85,10 +83,11 @@ export const AgentTableRow = ({
|
||||
categories,
|
||||
store_listing_version_id,
|
||||
changes_summary,
|
||||
listing_id,
|
||||
});
|
||||
|
||||
const canModify = status === SubmissionStatus.PENDING;
|
||||
// Determine if we should show Edit or View button
|
||||
const canEdit =
|
||||
status === SubmissionStatus.APPROVED || status === SubmissionStatus.PENDING;
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -115,22 +114,13 @@ export const AgentTableRow = ({
|
||||
</div>
|
||||
)}
|
||||
<div className="flex flex-col">
|
||||
<div className="flex items-center gap-2">
|
||||
<Text
|
||||
variant="h3"
|
||||
className="line-clamp-1 text-ellipsis text-neutral-800 dark:text-neutral-200"
|
||||
size="large-medium"
|
||||
>
|
||||
{agentName}
|
||||
</Text>
|
||||
<Text
|
||||
variant="body"
|
||||
size="small"
|
||||
className="text-neutral-500 dark:text-neutral-400"
|
||||
>
|
||||
v{agent_version}
|
||||
</Text>
|
||||
</div>
|
||||
<Text
|
||||
variant="h3"
|
||||
className="line-clamp-1 text-ellipsis text-neutral-800 dark:text-neutral-200"
|
||||
size="large-medium"
|
||||
>
|
||||
{agentName}
|
||||
</Text>
|
||||
<Text
|
||||
variant="body"
|
||||
className="line-clamp-1 text-ellipsis text-neutral-600 dark:text-neutral-400"
|
||||
@@ -160,7 +150,7 @@ export const AgentTableRow = ({
|
||||
{rating ? (
|
||||
<div className="flex items-center justify-end gap-1">
|
||||
<span className="text-sm font-medium">{rating.toFixed(1)}</span>
|
||||
<StarIcon weight="fill" className="h-2 w-2" />
|
||||
<Star weight="fill" className="h-2 w-2" />
|
||||
</div>
|
||||
) : (
|
||||
<span className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
@@ -176,12 +166,12 @@ export const AgentTableRow = ({
|
||||
<DotsThreeVerticalIcon className="h-5 w-5 text-neutral-800" />
|
||||
</DropdownMenu.Trigger>
|
||||
<DropdownMenu.Content className="z-10 rounded-xl border bg-white p-1 shadow-md dark:bg-gray-800">
|
||||
{canModify ? (
|
||||
{canEdit ? (
|
||||
<DropdownMenu.Item
|
||||
onSelect={handleEdit}
|
||||
className="flex cursor-pointer items-center rounded-md px-3 py-2 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||
>
|
||||
<PencilIcon className="mr-2 h-4 w-4 dark:text-gray-100" />
|
||||
<PencilSimple className="mr-2 h-4 w-4 dark:text-gray-100" />
|
||||
<span className="dark:text-gray-100">Edit</span>
|
||||
</DropdownMenu.Item>
|
||||
) : (
|
||||
@@ -189,22 +179,18 @@ export const AgentTableRow = ({
|
||||
onSelect={handleView}
|
||||
className="flex cursor-pointer items-center rounded-md px-3 py-2 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||
>
|
||||
<EyeIcon className="mr-2 h-4 w-4 dark:text-gray-100" />
|
||||
<Eye className="mr-2 h-4 w-4 dark:text-gray-100" />
|
||||
<span className="dark:text-gray-100">View</span>
|
||||
</DropdownMenu.Item>
|
||||
)}
|
||||
{canModify && (
|
||||
<>
|
||||
<DropdownMenu.Separator className="my-1 h-px bg-gray-300 dark:bg-gray-600" />
|
||||
<DropdownMenu.Item
|
||||
onSelect={handleDelete}
|
||||
className="flex cursor-pointer items-center rounded-md px-3 py-2 text-red-500 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||
>
|
||||
<TrashIcon className="mr-2 h-4 w-4 text-red-500 dark:text-red-400" />
|
||||
<span className="dark:text-red-400">Delete</span>
|
||||
</DropdownMenu.Item>
|
||||
</>
|
||||
)}
|
||||
<DropdownMenu.Separator className="my-1 h-px bg-gray-300 dark:bg-gray-600" />
|
||||
<DropdownMenu.Item
|
||||
onSelect={handleDelete}
|
||||
className="flex cursor-pointer items-center rounded-md px-3 py-2 text-red-500 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||
>
|
||||
<Trash className="mr-2 h-4 w-4 text-red-500 dark:text-red-400" />
|
||||
<span className="dark:text-red-400">Delete</span>
|
||||
</DropdownMenu.Item>
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
</div>
|
||||
|
||||
@@ -26,7 +26,6 @@ interface useAgentTableRowProps {
|
||||
categories?: string[];
|
||||
store_listing_version_id?: string;
|
||||
changes_summary?: string;
|
||||
listing_id?: string;
|
||||
}
|
||||
|
||||
export const useAgentTableRow = ({
|
||||
@@ -47,11 +46,9 @@ export const useAgentTableRow = ({
|
||||
categories,
|
||||
store_listing_version_id,
|
||||
changes_summary,
|
||||
listing_id,
|
||||
}: useAgentTableRowProps) => {
|
||||
const handleView = () => {
|
||||
onViewSubmission({
|
||||
listing_id: listing_id || "",
|
||||
agent_id,
|
||||
agent_version,
|
||||
slug: "",
|
||||
@@ -84,14 +81,7 @@ export const useAgentTableRow = ({
|
||||
};
|
||||
|
||||
const handleDelete = () => {
|
||||
// Backend only accepts StoreListingVersion IDs for deletion
|
||||
if (!store_listing_version_id) {
|
||||
console.error(
|
||||
"Cannot delete submission: store_listing_version_id is required",
|
||||
);
|
||||
return;
|
||||
}
|
||||
onDeleteSubmission(store_listing_version_id);
|
||||
onDeleteSubmission(agent_id);
|
||||
};
|
||||
|
||||
return { handleView, handleDelete, handleEdit };
|
||||
|
||||
@@ -99,7 +99,6 @@ export const MainDashboardPage = () => {
|
||||
store_listing_version_id:
|
||||
submission.store_listing_version_id || undefined,
|
||||
changes_summary: submission.changes_summary || undefined,
|
||||
listing_id: submission.listing_id,
|
||||
}))}
|
||||
onViewSubmission={onViewSubmission}
|
||||
onDeleteSubmission={onDeleteSubmission}
|
||||
|
||||
@@ -7665,7 +7665,6 @@
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"graph_id": { "type": "string", "title": "Graph Id" },
|
||||
"graph_version": { "type": "integer", "title": "Graph Version" },
|
||||
"owner_user_id": { "type": "string", "title": "Owner User Id" },
|
||||
"image_url": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Image Url"
|
||||
@@ -7748,7 +7747,6 @@
|
||||
"id",
|
||||
"graph_id",
|
||||
"graph_version",
|
||||
"owner_user_id",
|
||||
"image_url",
|
||||
"creator_name",
|
||||
"creator_image_url",
|
||||
@@ -9688,7 +9686,6 @@
|
||||
},
|
||||
"StoreSubmission": {
|
||||
"properties": {
|
||||
"listing_id": { "type": "string", "title": "Listing Id" },
|
||||
"agent_id": { "type": "string", "title": "Agent Id" },
|
||||
"agent_version": { "type": "integer", "title": "Agent Version" },
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
@@ -9760,7 +9757,6 @@
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"listing_id",
|
||||
"agent_id",
|
||||
"agent_version",
|
||||
"name",
|
||||
@@ -9823,18 +9819,8 @@
|
||||
},
|
||||
"StoreSubmissionRequest": {
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "Agent Id",
|
||||
"description": "Agent ID cannot be empty"
|
||||
},
|
||||
"agent_version": {
|
||||
"type": "integer",
|
||||
"exclusiveMinimum": 0.0,
|
||||
"title": "Agent Version",
|
||||
"description": "Agent version must be greater than 0"
|
||||
},
|
||||
"agent_id": { "type": "string", "title": "Agent Id" },
|
||||
"agent_version": { "type": "integer", "title": "Agent Version" },
|
||||
"slug": { "type": "string", "title": "Slug" },
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
"sub_heading": { "type": "string", "title": "Sub Heading" },
|
||||
|
||||
@@ -27,7 +27,6 @@ export function EditAgentModal({
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Edit Agent"
|
||||
styling={{
|
||||
maxWidth: "45rem",
|
||||
}}
|
||||
|
||||
@@ -8,6 +8,7 @@ import { Form, FormField } from "@/components/__legacy__/ui/form";
|
||||
import { StoreSubmission } from "@/app/api/__generated__/models/storeSubmission";
|
||||
import { ThumbnailImages } from "../../PublishAgentModal/components/AgentInfoStep/components/ThumbnailImages";
|
||||
import { StoreSubmissionEditRequest } from "@/app/api/__generated__/models/storeSubmissionEditRequest";
|
||||
import { StepHeader } from "../../PublishAgentModal/components/StepHeader";
|
||||
import { useEditAgentForm } from "./useEditAgentForm";
|
||||
|
||||
interface EditAgentFormProps {
|
||||
@@ -30,10 +31,12 @@ export function EditAgentForm({
|
||||
isSubmitting,
|
||||
handleFormSubmit,
|
||||
handleImagesChange,
|
||||
} = useEditAgentForm({ submission, onSuccess, onClose });
|
||||
} = useEditAgentForm({ submission, onSuccess });
|
||||
|
||||
return (
|
||||
<div className="mx-auto flex w-full flex-col rounded-3xl">
|
||||
<StepHeader title="Edit Agent" description="Update your agent details" />
|
||||
|
||||
<Form {...form}>
|
||||
<form
|
||||
onSubmit={form.handleSubmit(handleFormSubmit)}
|
||||
@@ -72,7 +75,7 @@ export function EditAgentForm({
|
||||
<ThumbnailImages
|
||||
agentId={submission.agent_id}
|
||||
onImagesChange={handleImagesChange}
|
||||
initialImages={Array.from(new Set(submission.image_urls || []))}
|
||||
initialImages={submission.image_urls || []}
|
||||
initialSelectedImage={submission.image_urls?.[0] || null}
|
||||
errorMessage={form.formState.errors.root?.message}
|
||||
/>
|
||||
@@ -133,7 +136,7 @@ export function EditAgentForm({
|
||||
<Input
|
||||
id={field.name}
|
||||
label="Changes Summary"
|
||||
type="textarea"
|
||||
type="text"
|
||||
placeholder="Briefly describe what you changed"
|
||||
error={form.formState.errors.changes_summary?.message}
|
||||
{...field}
|
||||
|
||||
@@ -19,13 +19,11 @@ interface useEditAgentFormProps {
|
||||
agent_id: string;
|
||||
};
|
||||
onSuccess: (submission: StoreSubmission) => void;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export const useEditAgentForm = ({
|
||||
submission,
|
||||
onSuccess,
|
||||
onClose,
|
||||
}: useEditAgentFormProps) => {
|
||||
const editAgentSchema = z.object({
|
||||
title: z
|
||||
@@ -47,7 +45,7 @@ export const useEditAgentForm = ({
|
||||
changes_summary: z
|
||||
.string()
|
||||
.min(1, "Changes summary is required")
|
||||
.max(500, "Changes summary must be less than 500 characters"),
|
||||
.max(200, "Changes summary must be less than 200 characters"),
|
||||
agentOutputDemo: z
|
||||
.string()
|
||||
.refine(validateYouTubeUrl, "Please enter a valid YouTube URL"),
|
||||
@@ -56,11 +54,19 @@ export const useEditAgentForm = ({
|
||||
type EditAgentFormData = z.infer<typeof editAgentSchema>;
|
||||
|
||||
const [images, setImages] = React.useState<string[]>(
|
||||
Array.from(new Set(submission.image_urls || [])), // Remove duplicates
|
||||
submission.image_urls || [],
|
||||
);
|
||||
const [isSubmitting, setIsSubmitting] = React.useState(false);
|
||||
|
||||
const { mutateAsync: editSubmission } = usePutV2EditStoreSubmission();
|
||||
const { mutateAsync: editSubmission } = usePutV2EditStoreSubmission({
|
||||
mutation: {
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListMySubmissionsQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const { toast } = useToast();
|
||||
@@ -126,20 +132,7 @@ export const useEditAgentForm = ({
|
||||
|
||||
// Extract the StoreSubmission from the response
|
||||
if (response.status === 200 && response.data) {
|
||||
toast({
|
||||
title: "Agent Updated",
|
||||
description: "Your agent submission has been updated successfully.",
|
||||
duration: 3000,
|
||||
variant: "default",
|
||||
});
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListMySubmissionsQueryKey(),
|
||||
});
|
||||
|
||||
// Call onSuccess and explicitly close the modal
|
||||
onSuccess(response.data);
|
||||
onClose();
|
||||
} else {
|
||||
throw new Error("Failed to update submission");
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import { useEffect, useCallback, useState } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { getGetV2ListMySubmissionsQueryKey } from "@/app/api/__generated__/endpoints/store/store";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import {
|
||||
PublishAgentFormData,
|
||||
@@ -31,6 +33,7 @@ export function useAgentInfoStep({
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const { toast } = useToast();
|
||||
const api = useBackendAPI();
|
||||
|
||||
@@ -51,26 +54,23 @@ export function useAgentInfoStep({
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (initialData?.agent_id) {
|
||||
if (initialData) {
|
||||
setAgentId(initialData.agent_id);
|
||||
setImages(
|
||||
Array.from(
|
||||
new Set([
|
||||
...(initialData?.thumbnailSrc ? [initialData.thumbnailSrc] : []),
|
||||
...(initialData.additionalImages || []),
|
||||
]),
|
||||
),
|
||||
);
|
||||
const initialImages = [
|
||||
...(initialData?.thumbnailSrc ? [initialData.thumbnailSrc] : []),
|
||||
...(initialData.additionalImages || []),
|
||||
];
|
||||
setImages(initialImages);
|
||||
|
||||
// Update form with initial data
|
||||
form.reset({
|
||||
changesSummary: isMarketplaceUpdate
|
||||
? ""
|
||||
: initialData.changesSummary || "",
|
||||
changesSummary: initialData.changesSummary || "",
|
||||
title: initialData.title,
|
||||
subheader: initialData.subheader,
|
||||
slug: initialData.slug.toLocaleLowerCase().trim(),
|
||||
youtubeLink: initialData.youtubeLink,
|
||||
category: initialData.category,
|
||||
description: isMarketplaceUpdate ? "" : initialData.description,
|
||||
description: initialData.description,
|
||||
recommendedScheduleCron: initialData.recommendedScheduleCron || "",
|
||||
instructions: initialData.instructions || "",
|
||||
agentOutputDemo: initialData.agentOutputDemo || "",
|
||||
@@ -78,13 +78,6 @@ export function useAgentInfoStep({
|
||||
}
|
||||
}, [initialData, form]);
|
||||
|
||||
// Ensure agentId is set from selectedAgentId if initialData doesn't have it
|
||||
useEffect(() => {
|
||||
if (selectedAgentId && !agentId) {
|
||||
setAgentId(selectedAgentId);
|
||||
}
|
||||
}, [selectedAgentId, agentId]);
|
||||
|
||||
const handleImagesChange = useCallback((newImages: string[]) => {
|
||||
setImages(newImages);
|
||||
}, []);
|
||||
@@ -99,16 +92,6 @@ export function useAgentInfoStep({
|
||||
return;
|
||||
}
|
||||
|
||||
// Validate that an agent is selected before submission
|
||||
if (!selectedAgentId || !selectedAgentVersion) {
|
||||
toast({
|
||||
title: "Agent Selection Required",
|
||||
description: "Please select an agent before submitting to the store.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const categories = data.category ? [data.category] : [];
|
||||
const filteredCategories = categories.filter(Boolean);
|
||||
|
||||
@@ -123,14 +106,18 @@ export function useAgentInfoStep({
|
||||
image_urls: images,
|
||||
video_url: data.youtubeLink || "",
|
||||
agent_output_demo_url: data.agentOutputDemo || "",
|
||||
agent_id: selectedAgentId,
|
||||
agent_version: selectedAgentVersion,
|
||||
agent_id: selectedAgentId || "",
|
||||
agent_version: selectedAgentVersion || 0,
|
||||
slug: (data.slug || "").replace(/\s+/g, "-"),
|
||||
categories: filteredCategories,
|
||||
recommended_schedule_cron: data.recommendedScheduleCron || null,
|
||||
changes_summary: data.changesSummary || null,
|
||||
} as any);
|
||||
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListMySubmissionsQueryKey(),
|
||||
});
|
||||
|
||||
onSuccess(response);
|
||||
} catch (error) {
|
||||
Sentry.captureException(error);
|
||||
@@ -152,7 +139,12 @@ export function useAgentInfoStep({
|
||||
agentId,
|
||||
images,
|
||||
isSubmitting,
|
||||
initialImages: images,
|
||||
initialImages: initialData
|
||||
? [
|
||||
...(initialData?.thumbnailSrc ? [initialData.thumbnailSrc] : []),
|
||||
...(initialData.additionalImages || []),
|
||||
]
|
||||
: [],
|
||||
initialSelectedImage: initialData?.thumbnailSrc || null,
|
||||
handleImagesChange,
|
||||
handleSubmit: form.handleSubmit(handleFormSubmit),
|
||||
|
||||
@@ -6,11 +6,9 @@ import { emptyModalState } from "./helpers";
|
||||
import {
|
||||
useGetV2GetMyAgents,
|
||||
useGetV2ListMySubmissions,
|
||||
getGetV2ListMySubmissionsQueryKey,
|
||||
} from "@/app/api/__generated__/endpoints/store/store";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import type { MyAgent } from "@/app/api/__generated__/models/myAgent";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
|
||||
const defaultTargetState: PublishState = {
|
||||
isOpen: false,
|
||||
@@ -67,7 +65,6 @@ export function usePublishAgentModal({
|
||||
>(preSelectedAgentVersion || null);
|
||||
|
||||
const router = useRouter();
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
// Fetch agent data for pre-populating form when agent is pre-selected
|
||||
const { data: myAgents } = useGetV2GetMyAgents();
|
||||
@@ -80,18 +77,14 @@ export function usePublishAgentModal({
|
||||
}
|
||||
}, [targetState]);
|
||||
|
||||
// Reset internal state when modal opens (only on initial open, not on every targetState change)
|
||||
const [hasOpened, setHasOpened] = useState(false);
|
||||
// Reset internal state when modal opens
|
||||
useEffect(() => {
|
||||
if (!targetState) return;
|
||||
if (targetState.isOpen && !hasOpened) {
|
||||
if (targetState.isOpen) {
|
||||
setSelectedAgent(null);
|
||||
setSelectedAgentId(preSelectedAgentId || null);
|
||||
setSelectedAgentVersion(preSelectedAgentVersion || null);
|
||||
setInitialData(emptyModalState);
|
||||
setHasOpened(true);
|
||||
} else if (!targetState.isOpen && hasOpened) {
|
||||
setHasOpened(false);
|
||||
}
|
||||
}, [targetState, preSelectedAgentId, preSelectedAgentVersion]);
|
||||
|
||||
@@ -179,11 +172,6 @@ export function usePublishAgentModal({
|
||||
setSelectedAgentVersion(null);
|
||||
setInitialData(emptyModalState);
|
||||
|
||||
// Invalidate submissions query to refresh the data after modal closes
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListMySubmissionsQueryKey(),
|
||||
});
|
||||
|
||||
// Update parent with clean closed state
|
||||
const newState = {
|
||||
isOpen: false,
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
import type { Meta, StoryObj } from "@storybook/nextjs";
|
||||
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { Table } from "./Table";
|
||||
|
||||
const meta = {
|
||||
title: "Molecules/Table",
|
||||
component: Table,
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<TooltipProvider>
|
||||
<Story />
|
||||
</TooltipProvider>
|
||||
),
|
||||
],
|
||||
parameters: {
|
||||
layout: "centered",
|
||||
},
|
||||
tags: ["autodocs"],
|
||||
argTypes: {
|
||||
allowAddRow: {
|
||||
control: "boolean",
|
||||
description: "Whether to show the Add row button",
|
||||
},
|
||||
allowDeleteRow: {
|
||||
control: "boolean",
|
||||
description: "Whether to show delete buttons for each row",
|
||||
},
|
||||
readOnly: {
|
||||
control: "boolean",
|
||||
description:
|
||||
"Whether the table is read-only (renders text instead of inputs)",
|
||||
},
|
||||
addRowLabel: {
|
||||
control: "text",
|
||||
description: "Label for the Add row button",
|
||||
},
|
||||
},
|
||||
} satisfies Meta<typeof Table>;
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof meta>;
|
||||
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
columns: ["name", "email", "role"],
|
||||
allowAddRow: true,
|
||||
allowDeleteRow: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const WithDefaultValues: Story = {
|
||||
args: {
|
||||
columns: ["name", "email", "role"],
|
||||
defaultValues: [
|
||||
{ name: "John Doe", email: "john@example.com", role: "Admin" },
|
||||
{ name: "Jane Smith", email: "jane@example.com", role: "User" },
|
||||
{ name: "Bob Wilson", email: "bob@example.com", role: "Editor" },
|
||||
],
|
||||
allowAddRow: true,
|
||||
allowDeleteRow: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const ReadOnly: Story = {
|
||||
args: {
|
||||
columns: ["name", "email"],
|
||||
defaultValues: [
|
||||
{ name: "John Doe", email: "john@example.com" },
|
||||
{ name: "Jane Smith", email: "jane@example.com" },
|
||||
],
|
||||
readOnly: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const NoAddOrDelete: Story = {
|
||||
args: {
|
||||
columns: ["name", "email"],
|
||||
defaultValues: [
|
||||
{ name: "John Doe", email: "john@example.com" },
|
||||
{ name: "Jane Smith", email: "jane@example.com" },
|
||||
],
|
||||
allowAddRow: false,
|
||||
allowDeleteRow: false,
|
||||
},
|
||||
};
|
||||
|
||||
export const SingleColumn: Story = {
|
||||
args: {
|
||||
columns: ["item"],
|
||||
allowAddRow: true,
|
||||
allowDeleteRow: true,
|
||||
addRowLabel: "Add item",
|
||||
},
|
||||
};
|
||||
|
||||
export const CustomAddLabel: Story = {
|
||||
args: {
|
||||
columns: ["key", "value"],
|
||||
allowAddRow: true,
|
||||
allowDeleteRow: true,
|
||||
addRowLabel: "Add new entry",
|
||||
},
|
||||
};
|
||||
|
||||
export const KeyValuePairs: Story = {
|
||||
args: {
|
||||
columns: ["key", "value"],
|
||||
defaultValues: [
|
||||
{ key: "API_KEY", value: "sk-..." },
|
||||
{ key: "DATABASE_URL", value: "postgres://..." },
|
||||
],
|
||||
allowAddRow: true,
|
||||
allowDeleteRow: true,
|
||||
addRowLabel: "Add variable",
|
||||
},
|
||||
};
|
||||
@@ -1,133 +0,0 @@
|
||||
import * as React from "react";
|
||||
import {
|
||||
Table as BaseTable,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Plus, Trash2 } from "lucide-react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useTable, RowData } from "./useTable";
|
||||
import { formatColumnTitle, formatPlaceholder } from "./helpers";
|
||||
|
||||
export interface TableProps {
|
||||
columns: string[];
|
||||
defaultValues?: RowData[];
|
||||
onChange?: (rows: RowData[]) => void;
|
||||
allowAddRow?: boolean;
|
||||
allowDeleteRow?: boolean;
|
||||
addRowLabel?: string;
|
||||
className?: string;
|
||||
readOnly?: boolean;
|
||||
}
|
||||
|
||||
export function Table({
|
||||
columns,
|
||||
defaultValues,
|
||||
onChange,
|
||||
allowAddRow = true,
|
||||
allowDeleteRow = true,
|
||||
addRowLabel = "Add row",
|
||||
className,
|
||||
readOnly = false,
|
||||
}: TableProps) {
|
||||
const { rows, handleAddRow, handleDeleteRow, handleCellChange } = useTable({
|
||||
columns,
|
||||
defaultValues,
|
||||
onChange,
|
||||
});
|
||||
|
||||
const showDeleteColumn = allowDeleteRow && !readOnly;
|
||||
const showAddButton = allowAddRow && !readOnly;
|
||||
|
||||
return (
|
||||
<div className={cn("flex flex-col gap-3", className)}>
|
||||
<div className="overflow-hidden rounded-xl border border-zinc-200 bg-white">
|
||||
<BaseTable>
|
||||
<TableHeader>
|
||||
<TableRow className="border-b border-zinc-100 bg-zinc-50/50">
|
||||
{columns.map((column) => (
|
||||
<TableHead
|
||||
key={column}
|
||||
className="h-10 px-3 text-sm font-medium text-zinc-600"
|
||||
>
|
||||
{formatColumnTitle(column)}
|
||||
</TableHead>
|
||||
))}
|
||||
{showDeleteColumn && <TableHead className="w-[50px]" />}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{rows.map((row, rowIndex) => (
|
||||
<TableRow key={rowIndex} className="border-none">
|
||||
{columns.map((column) => (
|
||||
<TableCell key={`${rowIndex}-${column}`} className="p-2">
|
||||
{readOnly ? (
|
||||
<Text
|
||||
variant="body"
|
||||
className="px-3 py-2 text-sm text-zinc-800"
|
||||
>
|
||||
{row[column] || "-"}
|
||||
</Text>
|
||||
) : (
|
||||
<Input
|
||||
id={`table-${rowIndex}-${column}`}
|
||||
label={formatColumnTitle(column)}
|
||||
hideLabel
|
||||
value={row[column] ?? ""}
|
||||
onChange={(e) =>
|
||||
handleCellChange(rowIndex, column, e.target.value)
|
||||
}
|
||||
placeholder={formatPlaceholder(column)}
|
||||
size="small"
|
||||
wrapperClassName="mb-0"
|
||||
/>
|
||||
)}
|
||||
</TableCell>
|
||||
))}
|
||||
{showDeleteColumn && (
|
||||
<TableCell className="p-2">
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
onClick={() => handleDeleteRow(rowIndex)}
|
||||
aria-label="Delete row"
|
||||
className="text-zinc-400 transition-colors hover:text-red-500"
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</TableCell>
|
||||
)}
|
||||
</TableRow>
|
||||
))}
|
||||
{showAddButton && (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={columns.length + (showDeleteColumn ? 1 : 0)}
|
||||
className="p-2"
|
||||
>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={handleAddRow}
|
||||
leftIcon={<Plus className="h-4 w-4" />}
|
||||
className="w-fit"
|
||||
>
|
||||
{addRowLabel}
|
||||
</Button>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</TableBody>
|
||||
</BaseTable>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { type RowData } from "./useTable";
|
||||
@@ -1,7 +0,0 @@
|
||||
export const formatColumnTitle = (key: string): string => {
|
||||
return key.charAt(0).toUpperCase() + key.slice(1);
|
||||
};
|
||||
|
||||
export const formatPlaceholder = (key: string): string => {
|
||||
return `Enter ${key.toLowerCase()}`;
|
||||
};
|
||||
@@ -1,81 +0,0 @@
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
export type RowData = Record<string, string>;
|
||||
|
||||
interface UseTableOptions {
|
||||
columns: string[];
|
||||
defaultValues?: RowData[];
|
||||
onChange?: (rows: RowData[]) => void;
|
||||
}
|
||||
|
||||
export function useTable({
|
||||
columns,
|
||||
defaultValues,
|
||||
onChange,
|
||||
}: UseTableOptions) {
|
||||
const createEmptyRow = (): RowData => {
|
||||
const emptyRow: RowData = {};
|
||||
columns.forEach((column) => {
|
||||
emptyRow[column] = "";
|
||||
});
|
||||
return emptyRow;
|
||||
};
|
||||
|
||||
const [rows, setRows] = useState<RowData[]>(() => {
|
||||
if (defaultValues && defaultValues.length > 0) {
|
||||
return defaultValues;
|
||||
}
|
||||
return [];
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (defaultValues !== undefined) {
|
||||
setRows(defaultValues);
|
||||
}
|
||||
}, [defaultValues]);
|
||||
|
||||
const updateRows = (newRows: RowData[]) => {
|
||||
setRows(newRows);
|
||||
onChange?.(newRows);
|
||||
};
|
||||
|
||||
const handleAddRow = () => {
|
||||
const newRows = [...rows, createEmptyRow()];
|
||||
updateRows(newRows);
|
||||
};
|
||||
|
||||
const handleDeleteRow = (rowIndex: number) => {
|
||||
const newRows = rows.filter((_, index) => index !== rowIndex);
|
||||
updateRows(newRows);
|
||||
};
|
||||
|
||||
const handleCellChange = (
|
||||
rowIndex: number,
|
||||
columnKey: string,
|
||||
value: string,
|
||||
) => {
|
||||
const newRows = rows.map((row, index) => {
|
||||
if (index === rowIndex) {
|
||||
return {
|
||||
...row,
|
||||
[columnKey]: value,
|
||||
};
|
||||
}
|
||||
return row;
|
||||
});
|
||||
updateRows(newRows);
|
||||
};
|
||||
|
||||
const clearAll = () => {
|
||||
updateRows([]);
|
||||
};
|
||||
|
||||
return {
|
||||
rows,
|
||||
handleAddRow,
|
||||
handleDeleteRow,
|
||||
handleCellChange,
|
||||
clearAll,
|
||||
createEmptyRow,
|
||||
};
|
||||
}
|
||||
@@ -37,3 +37,11 @@ html body .toastDescription {
|
||||
font-size: 0.75rem !important;
|
||||
line-height: 1.25rem !important;
|
||||
}
|
||||
|
||||
/* Position close button on the right */
|
||||
/* stylelint-disable-next-line selector-pseudo-class-no-unknown */
|
||||
#root [data-sonner-toast] [data-close-button="true"] {
|
||||
left: unset !important;
|
||||
right: -18px !important;
|
||||
top: -3px !important;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { CheckCircle, Info, Warning, XCircle } from "@phosphor-icons/react";
|
||||
import { Toaster as SonnerToaster } from "sonner";
|
||||
import styles from "./styles.module.css";
|
||||
|
||||
@@ -22,10 +23,10 @@ export function Toaster() {
|
||||
}}
|
||||
className="custom__toast"
|
||||
icons={{
|
||||
success: null,
|
||||
error: null,
|
||||
warning: null,
|
||||
info: null,
|
||||
success: <CheckCircle className="h-4 w-4" color="#fff" weight="fill" />,
|
||||
error: <XCircle className="h-4 w-4" color="#fff" weight="fill" />,
|
||||
warning: <Warning className="h-4 w-4" color="#fff" weight="fill" />,
|
||||
info: <Info className="h-4 w-4" color="#fff" weight="fill" />,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -30,8 +30,6 @@ export const FormRenderer = ({
|
||||
return generateUiSchemaForCustomFields(preprocessedSchema, uiSchema);
|
||||
}, [preprocessedSchema, uiSchema]);
|
||||
|
||||
console.log("preprocessedSchema", preprocessedSchema);
|
||||
|
||||
return (
|
||||
<div className={"mb-6 mt-4"}>
|
||||
<Form
|
||||
|
||||
@@ -5,14 +5,19 @@ import { useAnyOfField } from "./useAnyOfField";
|
||||
import { getHandleId, updateUiOption } from "../../helpers";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { ANY_OF_FLAG } from "../../constants";
|
||||
import { findCustomFieldId } from "../../registry";
|
||||
|
||||
export const AnyOfField = (props: FieldProps) => {
|
||||
const { registry, schema } = props;
|
||||
const { fields } = registry;
|
||||
const { SchemaField: _SchemaField } = fields;
|
||||
const { nodeId } = registry.formContext;
|
||||
|
||||
const { isInputConnected } = useEdgeStore();
|
||||
|
||||
const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions);
|
||||
|
||||
const Widget = getWidget({ type: "string" }, "select", registry.widgets);
|
||||
|
||||
const {
|
||||
handleOptionChange,
|
||||
enumOptions,
|
||||
@@ -21,15 +26,6 @@ export const AnyOfField = (props: FieldProps) => {
|
||||
field_id,
|
||||
} = useAnyOfField(props);
|
||||
|
||||
const parentCustomFieldId = findCustomFieldId(schema);
|
||||
if (parentCustomFieldId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions);
|
||||
|
||||
const Widget = getWidget({ type: "string" }, "select", registry.widgets);
|
||||
|
||||
const handleId = getHandleId({
|
||||
uiOptions,
|
||||
id: field_id + ANY_OF_FLAG,
|
||||
@@ -44,21 +40,12 @@ export const AnyOfField = (props: FieldProps) => {
|
||||
|
||||
const isHandleConnected = isInputConnected(nodeId, handleId);
|
||||
|
||||
// Now anyOf can render - custom fields if the option schema matches a custom field
|
||||
const optionCustomFieldId = optionSchema
|
||||
? findCustomFieldId(optionSchema)
|
||||
: null;
|
||||
|
||||
const optionUiSchema = optionCustomFieldId
|
||||
? { ...updatedUiSchema, "ui:field": optionCustomFieldId }
|
||||
: updatedUiSchema;
|
||||
|
||||
const optionsSchemaField =
|
||||
(optionSchema && optionSchema.type !== "null" && (
|
||||
<_SchemaField
|
||||
{...props}
|
||||
schema={optionSchema}
|
||||
uiSchema={optionUiSchema}
|
||||
uiSchema={updatedUiSchema}
|
||||
/>
|
||||
)) ||
|
||||
null;
|
||||
|
||||
@@ -17,7 +17,6 @@ interface InputExpanderModalProps {
|
||||
defaultValue: string;
|
||||
description?: string;
|
||||
placeholder?: string;
|
||||
inputType?: "text" | "json";
|
||||
}
|
||||
|
||||
export const InputExpanderModal: FC<InputExpanderModalProps> = ({
|
||||
@@ -28,7 +27,6 @@ export const InputExpanderModal: FC<InputExpanderModalProps> = ({
|
||||
defaultValue,
|
||||
description,
|
||||
placeholder,
|
||||
inputType = "text",
|
||||
}) => {
|
||||
const [tempValue, setTempValue] = useState(defaultValue);
|
||||
const [isCopied, setIsCopied] = useState(false);
|
||||
@@ -80,10 +78,7 @@ export const InputExpanderModal: FC<InputExpanderModalProps> = ({
|
||||
hideLabel
|
||||
id="input-expander-modal"
|
||||
value={tempValue}
|
||||
className={cn(
|
||||
"!min-h-[300px] rounded-2xlarge",
|
||||
inputType === "json" && "font-mono text-sm",
|
||||
)}
|
||||
className="!min-h-[300px] rounded-2xlarge"
|
||||
onChange={(e) => setTempValue(e.target.value)}
|
||||
placeholder={placeholder || "Enter text..."}
|
||||
autoFocus
|
||||
|
||||
@@ -8,34 +8,19 @@ import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/component
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { CredentialFieldTitle } from "./components/CredentialFieldTitle";
|
||||
import { Switch } from "@/components/atoms/Switch/Switch";
|
||||
|
||||
export const CredentialsField = (props: FieldProps) => {
|
||||
const { formData, onChange, schema, registry, fieldPathId, required } = props;
|
||||
const { formData, onChange, schema, registry, fieldPathId } = props;
|
||||
|
||||
const formContext = registry.formContext;
|
||||
const uiOptions = getUiOptions(props.uiSchema);
|
||||
const nodeId = formContext?.nodeId;
|
||||
|
||||
// Get sibling inputs (hardcoded values) and credentials optional state from the node store
|
||||
// Note: We select the node data directly instead of using getter functions to avoid
|
||||
// creating new object references that would cause infinite re-render loops with useShallow
|
||||
const { node, setCredentialsOptional } = useNodeStore(
|
||||
useShallow((state) => ({
|
||||
node: nodeId ? state.nodes.find((n) => n.id === nodeId) : undefined,
|
||||
setCredentialsOptional: state.setCredentialsOptional,
|
||||
})),
|
||||
// Get sibling inputs (hardcoded values) from the node store
|
||||
const hardcodedValues = useNodeStore(
|
||||
useShallow((state) => (nodeId ? state.getHardCodedValues(nodeId) : {})),
|
||||
);
|
||||
|
||||
const hardcodedValues = useMemo(
|
||||
() => node?.data?.hardcodedValues || {},
|
||||
[node?.data?.hardcodedValues],
|
||||
);
|
||||
const credentialsOptional = useMemo(() => {
|
||||
const value = node?.data?.metadata?.credentials_optional;
|
||||
return typeof value === "boolean" ? value : false;
|
||||
}, [node?.data?.metadata?.credentials_optional]);
|
||||
|
||||
const handleChange = (newValue: any) => {
|
||||
onChange(newValue, fieldPathId?.path);
|
||||
};
|
||||
@@ -67,10 +52,6 @@ export const CredentialsField = (props: FieldProps) => {
|
||||
[formData?.id, formData?.provider, formData?.title, formData?.type],
|
||||
);
|
||||
|
||||
// In builder canvas (nodeId exists): show star based on credentialsOptional toggle
|
||||
// In run dialogs (no nodeId): show star based on schema's required array
|
||||
const isRequired = nodeId ? !credentialsOptional : required;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<CredentialFieldTitle
|
||||
@@ -78,7 +59,6 @@ export const CredentialsField = (props: FieldProps) => {
|
||||
registry={registry}
|
||||
uiOptions={uiOptions}
|
||||
schema={schema}
|
||||
required={isRequired}
|
||||
/>
|
||||
<CredentialsInput
|
||||
schema={schema as BlockIOCredentialsSubSchema}
|
||||
@@ -87,31 +67,7 @@ export const CredentialsField = (props: FieldProps) => {
|
||||
siblingInputs={hardcodedValues}
|
||||
showTitle={false}
|
||||
readOnly={formContext?.readOnly}
|
||||
isOptional={!isRequired}
|
||||
className="w-full"
|
||||
variant="node"
|
||||
/>
|
||||
|
||||
{/* Optional credentials toggle - only show in builder canvas, not run dialogs */}
|
||||
{nodeId &&
|
||||
!formContext?.readOnly &&
|
||||
formContext?.showOptionalToggle !== false && (
|
||||
<div className="mt-1 flex items-center gap-2">
|
||||
<Switch
|
||||
id={`credentials-optional-${nodeId}`}
|
||||
checked={credentialsOptional}
|
||||
onCheckedChange={(checked) =>
|
||||
setCredentialsOptional(nodeId, checked)
|
||||
}
|
||||
/>
|
||||
<label
|
||||
htmlFor={`credentials-optional-${nodeId}`}
|
||||
className="cursor-pointer text-xs text-gray-500"
|
||||
>
|
||||
Optional - skip block if not configured
|
||||
</label>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -18,9 +18,8 @@ export const CredentialFieldTitle = (props: {
|
||||
uiOptions: UiSchema;
|
||||
schema: RJSFSchema;
|
||||
fieldPathId: FieldPathId;
|
||||
required?: boolean;
|
||||
}) => {
|
||||
const { registry, uiOptions, schema, fieldPathId, required = false } = props;
|
||||
const { registry, uiOptions, schema, fieldPathId } = props;
|
||||
const { nodeId } = registry.formContext;
|
||||
|
||||
const TitleFieldTemplate = getTemplate(
|
||||
@@ -51,7 +50,7 @@ export const CredentialFieldTitle = (props: {
|
||||
<TitleFieldTemplate
|
||||
id={titleId(fieldPathId ?? "")}
|
||||
title={credentialProvider ?? ""}
|
||||
required={required}
|
||||
required={true}
|
||||
schema={schema}
|
||||
registry={registry}
|
||||
uiSchema={updatedUiSchema}
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { FieldProps, getTemplate, getUiOptions } from "@rjsf/utils";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { ArrowsOutIcon } from "@phosphor-icons/react";
|
||||
import { InputExpanderModal } from "../../base/standard/widgets/TextInput/TextInputExpanderModal";
|
||||
import { getHandleId, updateUiOption } from "../../helpers";
|
||||
import { useJsonTextField } from "./useJsonTextField";
|
||||
import { getPlaceholder } from "./helpers";
|
||||
|
||||
export const JsonTextField = (props: FieldProps) => {
|
||||
const {
|
||||
formData,
|
||||
onChange,
|
||||
schema,
|
||||
registry,
|
||||
uiSchema,
|
||||
required,
|
||||
name,
|
||||
fieldPathId,
|
||||
} = props;
|
||||
|
||||
const uiOptions = getUiOptions(uiSchema);
|
||||
|
||||
const TitleFieldTemplate = getTemplate(
|
||||
"TitleFieldTemplate",
|
||||
registry,
|
||||
uiOptions,
|
||||
);
|
||||
|
||||
const fieldId = fieldPathId?.$id ?? props.id ?? "json-field";
|
||||
|
||||
const handleId = getHandleId({
|
||||
uiOptions,
|
||||
id: fieldId,
|
||||
schema: schema,
|
||||
});
|
||||
|
||||
const updatedUiSchema = updateUiOption(uiSchema, {
|
||||
handleId: handleId,
|
||||
});
|
||||
|
||||
const {
|
||||
textValue,
|
||||
isModalOpen,
|
||||
handleChange,
|
||||
handleModalOpen,
|
||||
handleModalClose,
|
||||
handleModalSave,
|
||||
} = useJsonTextField({
|
||||
formData,
|
||||
onChange,
|
||||
path: fieldPathId?.path,
|
||||
});
|
||||
|
||||
const placeholder = getPlaceholder(schema);
|
||||
const title = schema.title || name || "JSON Value";
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<TitleFieldTemplate
|
||||
id={fieldId}
|
||||
title={title}
|
||||
required={required}
|
||||
schema={schema}
|
||||
uiSchema={updatedUiSchema}
|
||||
registry={registry}
|
||||
/>
|
||||
<div className="nodrag relative flex items-center gap-2">
|
||||
<Input
|
||||
id={fieldId}
|
||||
hideLabel={true}
|
||||
type="textarea"
|
||||
label=""
|
||||
size="small"
|
||||
wrapperClassName="mb-0 flex-1 "
|
||||
value={textValue}
|
||||
onChange={handleChange}
|
||||
placeholder={placeholder}
|
||||
required={required}
|
||||
disabled={props.disabled}
|
||||
className="min-h-[60px] pr-8 font-mono text-xs"
|
||||
/>
|
||||
|
||||
<Tooltip delayDuration={0}>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={handleModalOpen}
|
||||
type="button"
|
||||
className="p-1"
|
||||
>
|
||||
<ArrowsOutIcon className="size-4" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Expand input</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
{schema.description && (
|
||||
<span className="text-xs text-gray-500">{schema.description}</span>
|
||||
)}
|
||||
|
||||
<InputExpanderModal
|
||||
isOpen={isModalOpen}
|
||||
onClose={handleModalClose}
|
||||
onSave={handleModalSave}
|
||||
title={`Edit ${title}`}
|
||||
description={schema.description || "Enter valid JSON"}
|
||||
defaultValue={textValue}
|
||||
placeholder={placeholder}
|
||||
inputType="json"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default JsonTextField;
|
||||
@@ -1,67 +0,0 @@
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
|
||||
/**
|
||||
* Converts form data to a JSON string for display
|
||||
* @param formData - The data to stringify
|
||||
* @returns JSON string or empty string if data is null/undefined
|
||||
*/
|
||||
export function stringifyFormData(formData: unknown): string {
|
||||
if (formData === undefined || formData === null) {
|
||||
return "";
|
||||
}
|
||||
try {
|
||||
return JSON.stringify(formData, null, 2);
|
||||
} catch {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a JSON string into an object/array
|
||||
* @param value - The JSON string to parse
|
||||
* @returns Parsed value or undefined if parsing fails or empty
|
||||
*/
|
||||
export function parseJsonValue(value: string): unknown | undefined {
|
||||
const trimmed = value.trim();
|
||||
if (trimmed === "") {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.parse(trimmed);
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the appropriate placeholder text based on schema type
|
||||
* @param schema - The JSON schema
|
||||
* @returns Placeholder string
|
||||
*/
|
||||
export function getPlaceholder(schema: RJSFSchema): string {
|
||||
if (schema.type === "array") {
|
||||
return '["item1", "item2"] or [{"key": "value"}]';
|
||||
}
|
||||
if (schema.type === "object") {
|
||||
return '{"key": "value"}';
|
||||
}
|
||||
return "Enter JSON value...";
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a JSON string is valid
|
||||
* @param value - The JSON string to validate
|
||||
* @returns true if valid JSON, false otherwise
|
||||
*/
|
||||
export function isValidJson(value: string): boolean {
|
||||
if (value.trim() === "") {
|
||||
return true; // Empty is considered valid (will be undefined)
|
||||
}
|
||||
try {
|
||||
JSON.parse(value);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { FieldProps } from "@rjsf/utils";
|
||||
import { stringifyFormData, parseJsonValue, isValidJson } from "./helpers";
|
||||
|
||||
type FieldOnChange = FieldProps["onChange"];
|
||||
type FieldPathId = FieldProps["fieldPathId"];
|
||||
|
||||
interface UseJsonTextFieldOptions {
|
||||
formData: unknown;
|
||||
onChange: FieldOnChange;
|
||||
path?: FieldPathId["path"];
|
||||
}
|
||||
|
||||
interface UseJsonTextFieldReturn {
|
||||
textValue: string;
|
||||
isModalOpen: boolean;
|
||||
hasError: boolean;
|
||||
handleChange: (
|
||||
e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>,
|
||||
) => void;
|
||||
handleModalOpen: () => void;
|
||||
handleModalClose: () => void;
|
||||
handleModalSave: (value: string) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom hook for managing JSON text field state and handlers
|
||||
*/
|
||||
export function useJsonTextField({
|
||||
formData,
|
||||
onChange,
|
||||
path,
|
||||
}: UseJsonTextFieldOptions): UseJsonTextFieldReturn {
|
||||
const [textValue, setTextValue] = useState(() => stringifyFormData(formData));
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
const [hasError, setHasError] = useState(false);
|
||||
|
||||
// Update text value when formData changes externally
|
||||
useEffect(() => {
|
||||
const newValue = stringifyFormData(formData);
|
||||
setTextValue(newValue);
|
||||
setHasError(false);
|
||||
}, [formData]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
|
||||
const value = e.target.value;
|
||||
setTextValue(value);
|
||||
|
||||
// Validate JSON and update error state
|
||||
const valid = isValidJson(value);
|
||||
setHasError(!valid);
|
||||
|
||||
// Try to parse and update formData
|
||||
if (value.trim() === "") {
|
||||
onChange(undefined, path ?? []);
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed = parseJsonValue(value);
|
||||
if (parsed !== undefined) {
|
||||
onChange(parsed, path ?? []);
|
||||
}
|
||||
},
|
||||
[onChange, path],
|
||||
);
|
||||
|
||||
const handleModalOpen = useCallback(() => {
|
||||
setIsModalOpen(true);
|
||||
}, []);
|
||||
|
||||
const handleModalClose = useCallback(() => {
|
||||
setIsModalOpen(false);
|
||||
}, []);
|
||||
|
||||
const handleModalSave = useCallback(
|
||||
(value: string) => {
|
||||
setTextValue(value);
|
||||
setIsModalOpen(false);
|
||||
|
||||
// Validate and update
|
||||
const valid = isValidJson(value);
|
||||
setHasError(!valid);
|
||||
|
||||
if (value.trim() === "") {
|
||||
onChange(undefined, path ?? []);
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed = parseJsonValue(value);
|
||||
if (parsed !== undefined) {
|
||||
onChange(parsed, path ?? []);
|
||||
}
|
||||
},
|
||||
[onChange, path],
|
||||
);
|
||||
|
||||
return {
|
||||
textValue,
|
||||
isModalOpen,
|
||||
hasError,
|
||||
handleChange,
|
||||
handleModalOpen,
|
||||
handleModalClose,
|
||||
handleModalSave,
|
||||
};
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
import React from "react";
|
||||
import { FieldProps, getUiOptions } from "@rjsf/utils";
|
||||
import { BlockIOObjectSubSchema } from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
MultiSelector,
|
||||
MultiSelectorContent,
|
||||
MultiSelectorInput,
|
||||
MultiSelectorItem,
|
||||
MultiSelectorList,
|
||||
MultiSelectorTrigger,
|
||||
} from "@/components/__legacy__/ui/multiselect";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useMultiSelectField } from "./useMultiSelectField";
|
||||
|
||||
export const MultiSelectField = (props: FieldProps) => {
|
||||
const { schema, formData, onChange, fieldPathId } = props;
|
||||
const uiOptions = getUiOptions(props.uiSchema);
|
||||
|
||||
const { optionSchema, options, selection, createChangeHandler } =
|
||||
useMultiSelectField({
|
||||
schema: schema as BlockIOObjectSubSchema,
|
||||
formData,
|
||||
});
|
||||
|
||||
const handleValuesChange = createChangeHandler(onChange, fieldPathId);
|
||||
|
||||
const displayName = schema.title || "options";
|
||||
|
||||
return (
|
||||
<div className={cn("flex flex-col", uiOptions.className)}>
|
||||
<MultiSelector
|
||||
className="nodrag"
|
||||
values={selection}
|
||||
onValuesChange={handleValuesChange}
|
||||
>
|
||||
<MultiSelectorTrigger className="rounded-3xl border border-zinc-200 bg-white px-2 shadow-none">
|
||||
<MultiSelectorInput
|
||||
placeholder={
|
||||
(schema as any).placeholder ?? `Select ${displayName}...`
|
||||
}
|
||||
/>
|
||||
</MultiSelectorTrigger>
|
||||
<MultiSelectorContent className="nowheel">
|
||||
<MultiSelectorList>
|
||||
{options
|
||||
.map((key) => ({ ...optionSchema[key], key }))
|
||||
.map(({ key, title, description }) => (
|
||||
<MultiSelectorItem key={key} value={key} title={description}>
|
||||
{title ?? key}
|
||||
</MultiSelectorItem>
|
||||
))}
|
||||
</MultiSelectorList>
|
||||
</MultiSelectorContent>
|
||||
</MultiSelector>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1 +0,0 @@
|
||||
export { MultiSelectField } from "./MultiSelectField";
|
||||
@@ -1,65 +0,0 @@
|
||||
import { FieldProps } from "@rjsf/utils";
|
||||
import { BlockIOObjectSubSchema } from "@/lib/autogpt-server-api/types";
|
||||
|
||||
type FormData = Record<string, boolean> | null | undefined;
|
||||
|
||||
interface UseMultiSelectFieldOptions {
|
||||
schema: BlockIOObjectSubSchema;
|
||||
formData: FormData;
|
||||
}
|
||||
|
||||
export function useMultiSelectField({
|
||||
schema,
|
||||
formData,
|
||||
}: UseMultiSelectFieldOptions) {
|
||||
const getOptionSchema = (): Record<string, BlockIOObjectSubSchema> => {
|
||||
if (schema.properties) {
|
||||
return schema.properties as Record<string, BlockIOObjectSubSchema>;
|
||||
}
|
||||
if (
|
||||
"anyOf" in schema &&
|
||||
Array.isArray(schema.anyOf) &&
|
||||
schema.anyOf.length > 0 &&
|
||||
"properties" in schema.anyOf[0]
|
||||
) {
|
||||
return (schema.anyOf[0] as BlockIOObjectSubSchema).properties as Record<
|
||||
string,
|
||||
BlockIOObjectSubSchema
|
||||
>;
|
||||
}
|
||||
return {};
|
||||
};
|
||||
|
||||
const optionSchema = getOptionSchema();
|
||||
const options = Object.keys(optionSchema);
|
||||
|
||||
const getSelection = (): string[] => {
|
||||
if (!formData || typeof formData !== "object") {
|
||||
return [];
|
||||
}
|
||||
return Object.entries(formData)
|
||||
.filter(([_, value]) => value === true)
|
||||
.map(([key]) => key);
|
||||
};
|
||||
|
||||
const selection = getSelection();
|
||||
|
||||
const createChangeHandler =
|
||||
(
|
||||
onChange: FieldProps["onChange"],
|
||||
fieldPathId: FieldProps["fieldPathId"],
|
||||
) =>
|
||||
(values: string[]) => {
|
||||
const newValue = Object.fromEntries(
|
||||
options.map((opt) => [opt, values.includes(opt)]),
|
||||
);
|
||||
onChange(newValue, fieldPathId?.path);
|
||||
};
|
||||
|
||||
return {
|
||||
optionSchema,
|
||||
options,
|
||||
selection,
|
||||
createChangeHandler,
|
||||
};
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
import { descriptionId, FieldProps, getTemplate, titleId } from "@rjsf/utils";
|
||||
import { Table, RowData } from "@/components/molecules/Table/Table";
|
||||
import { useMemo } from "react";
|
||||
|
||||
export const TableField = (props: FieldProps) => {
|
||||
const { schema, formData, onChange, fieldPathId, registry, uiSchema } = props;
|
||||
|
||||
const itemSchema = schema.items as any;
|
||||
const properties = itemSchema?.properties || {};
|
||||
|
||||
const columns: string[] = useMemo(() => {
|
||||
return Object.keys(properties);
|
||||
}, [properties]);
|
||||
|
||||
const handleChange = (rows: RowData[]) => {
|
||||
onChange(rows, fieldPathId?.path.slice(0, -1));
|
||||
};
|
||||
|
||||
const TitleFieldTemplate = getTemplate("TitleFieldTemplate", registry);
|
||||
const DescriptionFieldTemplate = getTemplate(
|
||||
"DescriptionFieldTemplate",
|
||||
registry,
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<TitleFieldTemplate
|
||||
id={titleId(fieldPathId)}
|
||||
title={schema.title || ""}
|
||||
required={true}
|
||||
schema={schema}
|
||||
uiSchema={uiSchema}
|
||||
registry={registry}
|
||||
/>
|
||||
<DescriptionFieldTemplate
|
||||
id={descriptionId(fieldPathId)}
|
||||
description={schema.description || ""}
|
||||
schema={schema}
|
||||
registry={registry}
|
||||
/>
|
||||
|
||||
<Table
|
||||
columns={columns}
|
||||
defaultValues={formData}
|
||||
onChange={handleChange}
|
||||
allowAddRow={true}
|
||||
allowDeleteRow={true}
|
||||
addRowLabel="Add row"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,10 +1,6 @@
|
||||
import { FieldProps, RJSFSchema, RegistryFieldsType } from "@rjsf/utils";
|
||||
import { CredentialsField } from "./CredentialField/CredentialField";
|
||||
import { GoogleDrivePickerField } from "./GoogleDrivePickerField/GoogleDrivePickerField";
|
||||
import { JsonTextField } from "./JsonTextField/JsonTextField";
|
||||
import { MultiSelectField } from "./MultiSelectField/MultiSelectField";
|
||||
import { isMultiSelectSchema } from "../utils/schema-utils";
|
||||
import { TableField } from "./TableField/TableField";
|
||||
|
||||
export interface CustomFieldDefinition {
|
||||
id: string;
|
||||
@@ -12,9 +8,6 @@ export interface CustomFieldDefinition {
|
||||
component: (props: FieldProps<any, RJSFSchema, any>) => JSX.Element | null;
|
||||
}
|
||||
|
||||
/** Field ID for JsonTextField - used to render nested complex types as text input */
|
||||
export const JSON_TEXT_FIELD_ID = "custom/json_text_field";
|
||||
|
||||
export const CUSTOM_FIELDS: CustomFieldDefinition[] = [
|
||||
{
|
||||
id: "custom/credential_field",
|
||||
@@ -37,28 +30,6 @@ export const CUSTOM_FIELDS: CustomFieldDefinition[] = [
|
||||
},
|
||||
component: GoogleDrivePickerField,
|
||||
},
|
||||
{
|
||||
id: "custom/json_text_field",
|
||||
// Not matched by schema - assigned via uiSchema for nested complex types
|
||||
matcher: () => false,
|
||||
component: JsonTextField,
|
||||
},
|
||||
{
|
||||
id: "custom/multi_select_field",
|
||||
matcher: isMultiSelectSchema,
|
||||
component: MultiSelectField,
|
||||
},
|
||||
{
|
||||
id: "custom/table_field",
|
||||
matcher: (schema: any) => {
|
||||
return (
|
||||
schema.type === "array" &&
|
||||
"format" in schema &&
|
||||
schema.format === "table"
|
||||
);
|
||||
},
|
||||
component: TableField,
|
||||
},
|
||||
];
|
||||
|
||||
export function findCustomFieldId(schema: any): string | null {
|
||||
|
||||
@@ -6,7 +6,6 @@ export interface ExtendedFormContextType extends FormContextType {
|
||||
uiType?: BlockUIType;
|
||||
showHandles?: boolean;
|
||||
size?: "small" | "medium" | "large";
|
||||
showOptionalToggle?: boolean;
|
||||
}
|
||||
|
||||
export type PathSegment = {
|
||||
|
||||
@@ -1,46 +1,19 @@
|
||||
import { RJSFSchema, UiSchema } from "@rjsf/utils";
|
||||
import {
|
||||
findCustomFieldId,
|
||||
JSON_TEXT_FIELD_ID,
|
||||
} from "../custom/custom-registry";
|
||||
|
||||
function isComplexType(schema: RJSFSchema): boolean {
|
||||
return schema.type === "object" || schema.type === "array";
|
||||
}
|
||||
|
||||
function hasComplexAnyOfOptions(schema: RJSFSchema): boolean {
|
||||
const options = schema.anyOf || schema.oneOf;
|
||||
if (!Array.isArray(options)) return false;
|
||||
return options.some(
|
||||
(opt: any) =>
|
||||
opt &&
|
||||
typeof opt === "object" &&
|
||||
(opt.type === "object" || opt.type === "array"),
|
||||
);
|
||||
}
|
||||
import { findCustomFieldId } from "../custom/custom-registry";
|
||||
|
||||
/**
|
||||
* Generates uiSchema with ui:field settings for custom fields based on schema matchers.
|
||||
* This is the standard RJSF way to route fields to custom components.
|
||||
*
|
||||
* Nested complex types (arrays/objects inside arrays/objects) are rendered as JsonTextField
|
||||
* to avoid deeply nested form UIs. Users can enter raw JSON for these fields.
|
||||
*
|
||||
* @param schema - The JSON schema
|
||||
* @param existingUiSchema - Existing uiSchema to merge with
|
||||
* @param insideComplexType - Whether we're already inside a complex type (object/array)
|
||||
*/
|
||||
export function generateUiSchemaForCustomFields(
|
||||
schema: RJSFSchema,
|
||||
existingUiSchema: UiSchema = {},
|
||||
insideComplexType: boolean = false,
|
||||
): UiSchema {
|
||||
const uiSchema: UiSchema = { ...existingUiSchema };
|
||||
|
||||
if (schema.properties) {
|
||||
for (const [key, propSchema] of Object.entries(schema.properties)) {
|
||||
if (propSchema && typeof propSchema === "object") {
|
||||
// First check for custom field matchers (credentials, google drive, etc.)
|
||||
const customFieldId = findCustomFieldId(propSchema);
|
||||
|
||||
if (customFieldId) {
|
||||
@@ -48,33 +21,8 @@ export function generateUiSchemaForCustomFields(
|
||||
...(uiSchema[key] as object),
|
||||
"ui:field": customFieldId,
|
||||
};
|
||||
// Skip further processing for custom fields
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle nested complex types - render as JsonTextField
|
||||
if (insideComplexType && isComplexType(propSchema as RJSFSchema)) {
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
};
|
||||
// Don't recurse further - this field is now a text input
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle anyOf/oneOf inside complex types
|
||||
if (
|
||||
insideComplexType &&
|
||||
hasComplexAnyOfOptions(propSchema as RJSFSchema)
|
||||
) {
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
};
|
||||
continue;
|
||||
}
|
||||
|
||||
// Recurse into object properties
|
||||
if (
|
||||
propSchema.type === "object" &&
|
||||
propSchema.properties &&
|
||||
@@ -83,7 +31,6 @@ export function generateUiSchemaForCustomFields(
|
||||
const nestedUiSchema = generateUiSchemaForCustomFields(
|
||||
propSchema as RJSFSchema,
|
||||
(uiSchema[key] as UiSchema) || {},
|
||||
true, // Now inside a complex type
|
||||
);
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
@@ -91,11 +38,9 @@ export function generateUiSchemaForCustomFields(
|
||||
};
|
||||
}
|
||||
|
||||
// Handle arrays
|
||||
if (propSchema.type === "array" && propSchema.items) {
|
||||
const itemsSchema = propSchema.items as RJSFSchema;
|
||||
if (itemsSchema && typeof itemsSchema === "object") {
|
||||
// Check for custom field on array items
|
||||
const itemsCustomFieldId = findCustomFieldId(itemsSchema);
|
||||
if (itemsCustomFieldId) {
|
||||
uiSchema[key] = {
|
||||
@@ -104,28 +49,10 @@ export function generateUiSchemaForCustomFields(
|
||||
"ui:field": itemsCustomFieldId,
|
||||
},
|
||||
};
|
||||
} else if (isComplexType(itemsSchema)) {
|
||||
// Array items that are complex types become JsonTextField
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
items: {
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
},
|
||||
};
|
||||
} else if (hasComplexAnyOfOptions(itemsSchema)) {
|
||||
// Array items with anyOf containing complex types become JsonTextField
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
items: {
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
},
|
||||
};
|
||||
} else if (itemsSchema.properties) {
|
||||
// Recurse into object items (but they're now inside a complex type)
|
||||
const itemsUiSchema = generateUiSchemaForCustomFields(
|
||||
itemsSchema,
|
||||
((uiSchema[key] as UiSchema)?.items as UiSchema) || {},
|
||||
true, // Inside complex type (array)
|
||||
);
|
||||
if (Object.keys(itemsUiSchema).length > 0) {
|
||||
uiSchema[key] = {
|
||||
@@ -136,61 +63,6 @@ export function generateUiSchemaForCustomFields(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle anyOf/oneOf at root level - process complex options
|
||||
if (!insideComplexType) {
|
||||
const anyOfOptions = propSchema.anyOf || propSchema.oneOf;
|
||||
|
||||
if (Array.isArray(anyOfOptions)) {
|
||||
for (let i = 0; i < anyOfOptions.length; i++) {
|
||||
const option = anyOfOptions[i] as RJSFSchema;
|
||||
if (option && typeof option === "object") {
|
||||
// Handle anyOf array options with complex items
|
||||
if (option.type === "array" && option.items) {
|
||||
const itemsSchema = option.items as RJSFSchema;
|
||||
if (itemsSchema && typeof itemsSchema === "object") {
|
||||
// Array items that are complex types become JsonTextField
|
||||
if (isComplexType(itemsSchema)) {
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
items: {
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
},
|
||||
};
|
||||
} else if (hasComplexAnyOfOptions(itemsSchema)) {
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
items: {
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recurse into anyOf object options with properties
|
||||
if (
|
||||
option.type === "object" &&
|
||||
option.properties &&
|
||||
typeof option.properties === "object"
|
||||
) {
|
||||
const optionUiSchema = generateUiSchemaForCustomFields(
|
||||
option,
|
||||
{},
|
||||
true, // Inside complex type (anyOf object option)
|
||||
);
|
||||
if (Object.keys(optionUiSchema).length > 0) {
|
||||
// Store under the property key - RJSF will apply it
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
...optionUiSchema,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import { getUiOptions, RJSFSchema, UiSchema } from "@rjsf/utils";
|
||||
|
||||
export function isAnyOfSchema(schema: RJSFSchema | undefined): boolean {
|
||||
return (
|
||||
Array.isArray(schema?.anyOf) &&
|
||||
schema!.anyOf.length > 0 &&
|
||||
schema?.enum === undefined
|
||||
);
|
||||
return Array.isArray(schema?.anyOf) && schema!.anyOf.length > 0;
|
||||
}
|
||||
|
||||
export const isAnyOfChild = (
|
||||
@@ -37,21 +33,3 @@ export function isOptionalType(schema: RJSFSchema | undefined): {
|
||||
export function isAnyOfSelector(name: string) {
|
||||
return name.includes("anyof_select");
|
||||
}
|
||||
|
||||
export function isMultiSelectSchema(schema: RJSFSchema | undefined): boolean {
|
||||
if (typeof schema !== "object" || schema === null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ("anyOf" in schema || "oneOf" in schema) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !!(
|
||||
schema.type === "object" &&
|
||||
schema.properties &&
|
||||
Object.values(schema.properties).every(
|
||||
(prop: any) => prop.type === "boolean",
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user