Compare commits

..

5 Commits

Author SHA1 Message Date
Swifty
402bec4595 pr comments 2026-01-08 14:43:07 +01:00
Swifty
1c8cba9c5f fix more linting issues 2026-01-08 13:30:51 +01:00
Swifty
072c647baa fix addintal formatting issues 2026-01-08 13:24:30 +01:00
Swifty
d5f490b85d Merge branch 'dev' into swiftyos/fix-linting-errors 2026-01-08 13:07:05 +01:00
Swifty
6686de1701 fix linter errors 2026-01-08 13:04:02 +01:00
173 changed files with 1428 additions and 6190 deletions

View File

@@ -1,37 +0,0 @@
{
"worktreeCopyPatterns": [
".env*",
".vscode/**",
".auth/**",
".claude/**",
"autogpt_platform/.env*",
"autogpt_platform/backend/.env*",
"autogpt_platform/frontend/.env*",
"autogpt_platform/frontend/.auth/**",
"autogpt_platform/db/docker/.env*"
],
"worktreeCopyIgnores": [
"**/node_modules/**",
"**/dist/**",
"**/.git/**",
"**/Thumbs.db",
"**/.DS_Store",
"**/.next/**",
"**/__pycache__/**",
"**/.ruff_cache/**",
"**/.pytest_cache/**",
"**/*.pyc",
"**/playwright-report/**",
"**/logs/**",
"**/site/**"
],
"worktreePathTemplate": "$BASE_PATH.worktree",
"postCreateCmd": [
"cd autogpt_platform/autogpt_libs && poetry install",
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
"cd autogpt_platform/frontend && pnpm install",
"cd docs && pip install -r requirements.txt"
],
"terminalCommand": "code .",
"deleteBranchWithWorktree": false
}

View File

@@ -16,7 +16,6 @@
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py
# Platform - Market
!autogpt_platform/market/market/

View File

@@ -74,7 +74,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -90,7 +90,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -72,7 +72,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
@@ -108,16 +108,6 @@ jobs:
# run: pnpm playwright install --with-deps chromium
# Docker setup for development environment
- name: Free up disk space
run: |
# Remove large unused tools to free disk space for Docker builds
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker system prune -af
df -h
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -134,7 +134,7 @@ jobs:
run: poetry install
- name: Generate Prisma Client
run: poetry run prisma generate && poetry run gen-prisma-stub
run: poetry run prisma generate
- id: supabase
name: Start Supabase

View File

@@ -12,7 +12,6 @@ reset-db:
rm -rf db/docker/volumes/db/data
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
# View logs for core services
logs-core:
@@ -34,7 +33,6 @@ init-env:
migrate:
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
run-backend:
cd backend && poetry run app

View File

@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
RUN poetry run prisma generate
FROM debian:13-slim AS server_dependencies

View File

@@ -3,6 +3,7 @@ from datetime import UTC, datetime
from os import getenv
import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
@@ -49,13 +50,13 @@ async def setup_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile",
links=[], # Required field - empty array for test profiles
)
)
# 2. Create a test graph with agent input -> agent output
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for LLM tests",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for LLM tests",
links=[], # Required field - empty array for test profiles
)
)
# 2. Create test OpenAI credentials for the user
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for Firecrawl tests",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for Firecrawl tests",
links=[], # Required field - empty array for test profiles
)
)
# NOTE: We deliberately do NOT create Firecrawl credentials for this user

View File

@@ -489,7 +489,7 @@ async def update_agent_version_in_library(
agent_graph_version: int,
) -> library_model.LibraryAgent:
"""
Updates the agent version in the library for any agent owned by the user.
Updates the agent version in the library if useGraphIsActiveVersion is True.
Args:
user_id: Owner of the LibraryAgent.
@@ -498,31 +498,20 @@ async def update_agent_version_in_library(
Raises:
DatabaseError: If there's an error with the update.
NotFoundError: If no library agent is found for this user and agent.
"""
logger.debug(
f"Updating agent version in library for user #{user_id}, "
f"agent #{agent_graph_id} v{agent_graph_version}"
)
async with transaction() as tx:
library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise(
try:
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
where={
"userId": user_id,
"agentGraphId": agent_graph_id,
"useGraphIsActiveVersion": True,
},
)
# Delete any conflicting LibraryAgent for the target version
await prisma.models.LibraryAgent.prisma(tx).delete_many(
where={
"userId": user_id,
"agentGraphId": agent_graph_id,
"agentGraphVersion": agent_graph_version,
"id": {"not": library_agent.id},
}
)
lib = await prisma.models.LibraryAgent.prisma(tx).update(
lib = await prisma.models.LibraryAgent.prisma().update(
where={"id": library_agent.id},
data={
"AgentGraph": {
@@ -536,13 +525,13 @@ async def update_agent_version_in_library(
},
include={"AgentGraph": True},
)
if lib is None:
raise NotFoundError(f"Library agent {library_agent.id} not found")
if lib is None:
raise NotFoundError(
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
)
return library_model.LibraryAgent.from_db(lib)
return library_model.LibraryAgent.from_db(lib)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating agent version in library: {e}")
raise DatabaseError("Failed to update agent version in library") from e
async def update_library_agent(
@@ -828,19 +817,16 @@ async def add_store_agent_to_library(
# Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"User": {"connect": {"id": user_id}},
"AgentGraph": {
data=prisma.types.LibraryAgentCreateInput(
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": SafeJson(
_initialize_graph_settings(graph_model).model_dump()
),
},
isCreatedByUser=False,
settings=SafeJson(_initialize_graph_settings(graph_model).model_dump()),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),

View File

@@ -48,7 +48,6 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str # ID of user who owns/created this agent graph
image_url: str | None
@@ -164,7 +163,6 @@ class LibraryAgent(pydantic.BaseModel):
id=agent.id,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
owner_user_id=agent.userId,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,

View File

@@ -42,7 +42,6 @@ async def test_get_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -65,7 +64,6 @@ async def test_get_library_agents_success(
id="test-agent-2",
graph_id="test-agent-2",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -140,7 +138,6 @@ async def test_get_favorite_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Favorite Agent 1",
description="Test Favorite Description 1",
image_url=None,
@@ -208,7 +205,6 @@ def test_add_agent_to_library_success(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,

View File

@@ -27,6 +27,13 @@ from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.models import User as PrismaUser
from prisma.types import (
OAuthAccessTokenCreateInput,
OAuthApplicationCreateInput,
OAuthAuthorizationCodeCreateInput,
OAuthRefreshTokenCreateInput,
UserCreateInput,
)
from backend.api.rest_api import app
@@ -48,11 +55,11 @@ def test_user_id() -> str:
async def test_user(server, test_user_id: str):
"""Create a test user in the database."""
await PrismaUser.prisma().create(
data={
"id": test_user_id,
"email": f"oauth-test-{test_user_id}@example.com",
"name": "OAuth Test User",
}
data=UserCreateInput(
id=test_user_id,
email=f"oauth-test-{test_user_id}@example.com",
name="OAuth Test User",
)
)
yield test_user_id
@@ -77,22 +84,22 @@ async def test_oauth_app(test_user: str):
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create(
data={
"id": app_id,
"name": "Test OAuth App",
"description": "Test application for integration tests",
"clientId": client_id,
"clientSecret": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"redirectUris": [
data=OAuthApplicationCreateInput(
id=app_id,
name="Test OAuth App",
description="Test application for integration tests",
clientId=client_id,
clientSecret=client_secret_hash,
clientSecretSalt=client_secret_salt,
redirectUris=[
"https://example.com/callback",
"http://localhost:3000/callback",
],
"grantTypes": ["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
"ownerId": test_user,
"isActive": True,
}
grantTypes=["authorization_code", "refresh_token"],
scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
ownerId=test_user,
isActive=True,
)
)
yield {
@@ -296,19 +303,19 @@ async def inactive_oauth_app(test_user: str):
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create(
data={
"id": app_id,
"name": "Inactive OAuth App",
"description": "Inactive test application",
"clientId": client_id,
"clientSecret": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"redirectUris": ["https://example.com/callback"],
"grantTypes": ["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"ownerId": test_user,
"isActive": False, # Inactive!
}
data=OAuthApplicationCreateInput(
id=app_id,
name="Inactive OAuth App",
description="Inactive test application",
clientId=client_id,
clientSecret=client_secret_hash,
clientSecretSalt=client_secret_salt,
redirectUris=["https://example.com/callback"],
grantTypes=["authorization_code", "refresh_token"],
scopes=[APIKeyPermission.EXECUTE_GRAPH],
ownerId=test_user,
isActive=False, # Inactive!
)
)
yield {
@@ -699,14 +706,14 @@ async def test_token_authorization_code_expired(
now = datetime.now(timezone.utc)
await PrismaOAuthAuthorizationCode.prisma().create(
data={
"code": expired_code,
"applicationId": test_oauth_app["id"],
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"redirectUri": test_oauth_app["redirect_uri"],
"expiresAt": now - timedelta(hours=1), # Already expired
}
data=OAuthAuthorizationCodeCreateInput(
code=expired_code,
applicationId=test_oauth_app["id"],
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
redirectUri=test_oauth_app["redirect_uri"],
expiresAt=now - timedelta(hours=1), # Already expired
)
)
response = await client.post(
@@ -942,13 +949,13 @@ async def test_token_refresh_expired(
now = datetime.now(timezone.utc)
await PrismaOAuthRefreshToken.prisma().create(
data={
"token": expired_token_hash,
"applicationId": test_oauth_app["id"],
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now - timedelta(days=1), # Already expired
}
data=OAuthRefreshTokenCreateInput(
token=expired_token_hash,
applicationId=test_oauth_app["id"],
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
expiresAt=now - timedelta(days=1), # Already expired
)
)
response = await client.post(
@@ -980,14 +987,14 @@ async def test_token_refresh_revoked(
now = datetime.now(timezone.utc)
await PrismaOAuthRefreshToken.prisma().create(
data={
"token": revoked_token_hash,
"applicationId": test_oauth_app["id"],
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now + timedelta(days=30), # Not expired
"revokedAt": now - timedelta(hours=1), # But revoked
}
data=OAuthRefreshTokenCreateInput(
token=revoked_token_hash,
applicationId=test_oauth_app["id"],
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
expiresAt=now + timedelta(days=30), # Not expired
revokedAt=now - timedelta(hours=1), # But revoked
)
)
response = await client.post(
@@ -1013,19 +1020,19 @@ async def other_oauth_app(test_user: str):
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create(
data={
"id": app_id,
"name": "Other OAuth App",
"description": "Second test application",
"clientId": client_id,
"clientSecret": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"redirectUris": ["https://other.example.com/callback"],
"grantTypes": ["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"ownerId": test_user,
"isActive": True,
}
data=OAuthApplicationCreateInput(
id=app_id,
name="Other OAuth App",
description="Second test application",
clientId=client_id,
clientSecret=client_secret_hash,
clientSecretSalt=client_secret_salt,
redirectUris=["https://other.example.com/callback"],
grantTypes=["authorization_code", "refresh_token"],
scopes=[APIKeyPermission.EXECUTE_GRAPH],
ownerId=test_user,
isActive=True,
)
)
yield {
@@ -1052,13 +1059,13 @@ async def test_token_refresh_wrong_application(
now = datetime.now(timezone.utc)
await PrismaOAuthRefreshToken.prisma().create(
data={
"token": token_hash,
"applicationId": test_oauth_app["id"], # Belongs to test_oauth_app
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now + timedelta(days=30),
}
data=OAuthRefreshTokenCreateInput(
token=token_hash,
applicationId=test_oauth_app["id"], # Belongs to test_oauth_app
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
expiresAt=now + timedelta(days=30),
)
)
# Try to use it with `other_oauth_app`
@@ -1267,19 +1274,19 @@ async def test_validate_access_token_fails_when_app_disabled(
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create(
data={
"id": app_id,
"name": "App To Be Disabled",
"description": "Test app for disabled validation",
"clientId": client_id,
"clientSecret": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"redirectUris": ["https://example.com/callback"],
"grantTypes": ["authorization_code"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"ownerId": test_user,
"isActive": True,
}
data=OAuthApplicationCreateInput(
id=app_id,
name="App To Be Disabled",
description="Test app for disabled validation",
clientId=client_id,
clientSecret=client_secret_hash,
clientSecretSalt=client_secret_salt,
redirectUris=["https://example.com/callback"],
grantTypes=["authorization_code"],
scopes=[APIKeyPermission.EXECUTE_GRAPH],
ownerId=test_user,
isActive=True,
)
)
# Create an access token directly in the database
@@ -1288,13 +1295,13 @@ async def test_validate_access_token_fails_when_app_disabled(
now = datetime.now(timezone.utc)
await PrismaOAuthAccessToken.prisma().create(
data={
"token": token_hash,
"applicationId": app_id,
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now + timedelta(hours=1),
}
data=OAuthAccessTokenCreateInput(
token=token_hash,
applicationId=app_id,
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
expiresAt=now + timedelta(hours=1),
)
)
# Token should be valid while app is active
@@ -1561,19 +1568,19 @@ async def test_revoke_token_from_different_app_fails_silently(
)
await PrismaOAuthApplication.prisma().create(
data={
"id": app2_id,
"name": "Second Test OAuth App",
"description": "Second test application for cross-app revocation test",
"clientId": app2_client_id,
"clientSecret": app2_client_secret_hash,
"clientSecretSalt": app2_client_secret_salt,
"redirectUris": ["https://other-app.com/callback"],
"grantTypes": ["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
"ownerId": test_user,
"isActive": True,
}
data=OAuthApplicationCreateInput(
id=app2_id,
name="Second Test OAuth App",
description="Second test application for cross-app revocation test",
clientId=app2_client_id,
clientSecret=app2_client_secret_hash,
clientSecretSalt=app2_client_secret_salt,
redirectUris=["https://other-app.com/callback"],
grantTypes=["authorization_code", "refresh_token"],
scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
ownerId=test_user,
isActive=True,
)
)
# App 2 tries to revoke App 1's access token

View File

@@ -249,7 +249,9 @@ async def log_search_term(search_query: str):
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
try:
await prisma.models.SearchTerms.prisma().create(
data={"searchTerm": search_query, "createdDate": date}
data=prisma.types.SearchTermsCreateInput(
searchTerm=search_query, createdDate=date
)
)
except Exception as e:
# Fail silently here so that logging search terms doesn't break the app
@@ -614,7 +616,6 @@ async def get_store_submissions(
submission_models = []
for sub in submissions:
submission_model = store_model.StoreSubmission(
listing_id=sub.listing_id,
agent_id=sub.agent_id,
agent_version=sub.agent_version,
name=sub.name,
@@ -668,48 +669,35 @@ async def delete_store_submission(
submission_id: str,
) -> bool:
"""
Delete a store submission version as the submitting user.
Delete a store listing submission as the submitting user.
Args:
user_id: ID of the authenticated user
submission_id: StoreListingVersion ID to delete
submission_id: ID of the submission to be deleted
Returns:
bool: True if successfully deleted
bool: True if the submission was successfully deleted, False otherwise
"""
logger.debug(f"Deleting store submission {submission_id} for user {user_id}")
try:
# Find the submission version with ownership check
version = await prisma.models.StoreListingVersion.prisma().find_first(
where={"id": submission_id}, include={"StoreListing": True}
# Verify the submission belongs to this user
submission = await prisma.models.StoreListing.prisma().find_first(
where={"agentGraphId": submission_id, "owningUserId": user_id}
)
if (
not version
or not version.StoreListing
or version.StoreListing.owningUserId != user_id
):
raise store_exceptions.SubmissionNotFoundError("Submission not found")
# Prevent deletion of approved submissions
if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
raise store_exceptions.InvalidOperationError(
"Cannot delete approved submissions"
if not submission:
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
raise store_exceptions.SubmissionNotFoundError(
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
)
# Delete the version
await prisma.models.StoreListingVersion.prisma().delete(
where={"id": version.id}
)
# Delete the submission
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
# Clean up empty listing if this was the last version
remaining = await prisma.models.StoreListingVersion.prisma().count(
where={"storeListingId": version.storeListingId}
logger.debug(
f"Successfully deleted submission {submission_id} for user {user_id}"
)
if remaining == 0:
await prisma.models.StoreListing.prisma().delete(
where={"id": version.storeListingId}
)
return True
except Exception as e:
@@ -773,15 +761,9 @@ async def create_store_submission(
logger.warning(
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
)
# Provide more user-friendly error message when agent_id is empty
if not agent_id or agent_id.strip() == "":
raise store_exceptions.AgentNotFoundError(
"No agent selected. Please select an agent before submitting to the store."
)
else:
raise store_exceptions.AgentNotFoundError(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
raise store_exceptions.AgentNotFoundError(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
# Check if listing already exists for this agent
existing_listing = await prisma.models.StoreListing.prisma().find_first(
@@ -853,7 +835,6 @@ async def create_store_submission(
logger.debug(f"Created store listing for agent {agent_id}")
# Return submission details
return store_model.StoreSubmission(
listing_id=listing.id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -965,56 +946,81 @@ async def edit_store_submission(
# Currently we are not allowing user to update the agent associated with a submission
# If we allow it in future, then we need a check here to verify the agent belongs to this user.
# Only allow editing of PENDING submissions
if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING:
# Check if we can edit this submission
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
raise store_exceptions.InvalidOperationError(
f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited."
"Cannot edit a rejected submission"
)
# For APPROVED submissions, we need to create a new version
if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
# Create a new version for the existing listing
return await create_store_version(
user_id=user_id,
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
store_listing_id=current_version.storeListingId,
name=name,
video_url=video_url,
agent_output_demo_url=agent_output_demo_url,
image_urls=image_urls,
description=description,
sub_heading=sub_heading,
categories=categories,
changes_summary=changes_summary,
recommended_schedule_cron=recommended_schedule_cron,
instructions=instructions,
)
# For PENDING submissions, we can update the existing version
# Update the existing version
updated_version = await prisma.models.StoreListingVersion.prisma().update(
where={"id": store_listing_version_id},
data=prisma.types.StoreListingVersionUpdateInput(
elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING:
# Update the existing version
updated_version = await prisma.models.StoreListingVersion.prisma().update(
where={"id": store_listing_version_id},
data=prisma.types.StoreListingVersionUpdateInput(
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
categories=categories,
subHeading=sub_heading,
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
instructions=instructions,
),
)
logger.debug(
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
)
if not updated_version:
raise DatabaseError("Failed to update store listing version")
return store_model.StoreSubmission(
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
sub_heading=sub_heading,
slug=current_version.StoreListing.slug,
description=description,
categories=categories,
subHeading=sub_heading,
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
instructions=instructions,
),
)
image_urls=image_urls,
date_submitted=updated_version.submittedAt or updated_version.createdAt,
status=updated_version.submissionStatus,
runs=0,
rating=0.0,
store_listing_version_id=updated_version.id,
changes_summary=changes_summary,
video_url=video_url,
categories=categories,
version=updated_version.version,
)
logger.debug(
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
)
if not updated_version:
raise DatabaseError("Failed to update store listing version")
return store_model.StoreSubmission(
listing_id=current_version.StoreListing.id,
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
name=name,
sub_heading=sub_heading,
slug=current_version.StoreListing.slug,
description=description,
instructions=instructions,
image_urls=image_urls,
date_submitted=updated_version.submittedAt or updated_version.createdAt,
status=updated_version.submissionStatus,
runs=0,
rating=0.0,
store_listing_version_id=updated_version.id,
changes_summary=changes_summary,
video_url=video_url,
categories=categories,
version=updated_version.version,
)
else:
raise store_exceptions.InvalidOperationError(
f"Cannot edit submission with status: {current_version.submissionStatus}"
)
except (
store_exceptions.SubmissionNotFoundError,
@@ -1093,78 +1099,38 @@ async def create_store_version(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
# Check if there's already a PENDING submission for this agent (any version)
existing_pending_submission = (
await prisma.models.StoreListingVersion.prisma().find_first(
where=prisma.types.StoreListingVersionWhereInput(
storeListingId=store_listing_id,
agentGraphId=agent_id,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isDeleted=False,
)
# Get the latest version number
latest_version = listing.Versions[0] if listing.Versions else None
next_version = (latest_version.version + 1) if latest_version else 1
# Create a new version for the existing listing
new_version = await prisma.models.StoreListingVersion.prisma().create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
instructions=instructions,
categories=categories,
subHeading=sub_heading,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
submittedAt=datetime.now(),
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
storeListingId=store_listing_id,
)
)
# Handle existing pending submission and create new one atomically
async with transaction() as tx:
# Get the latest version number first
latest_listing = await prisma.models.StoreListing.prisma(tx).find_first(
where=prisma.types.StoreListingWhereInput(
id=store_listing_id, owningUserId=user_id
),
include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
)
if not latest_listing:
raise store_exceptions.ListingNotFoundError(
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
)
latest_version = (
latest_listing.Versions[0] if latest_listing.Versions else None
)
next_version = (latest_version.version + 1) if latest_version else 1
# If there's an existing pending submission, delete it atomically before creating new one
if existing_pending_submission:
logger.info(
f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate"
)
await prisma.models.StoreListingVersion.prisma(tx).delete(
where={"id": existing_pending_submission.id}
)
logger.debug(
f"Deleted existing pending submission {existing_pending_submission.id}"
)
# Create a new version for the existing listing
new_version = await prisma.models.StoreListingVersion.prisma(tx).create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
instructions=instructions,
categories=categories,
subHeading=sub_heading,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
submittedAt=datetime.now(),
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
storeListingId=store_listing_id,
)
)
logger.debug(
f"Created new version for listing {store_listing_id} of agent {agent_id}"
)
# Return submission details
return store_model.StoreSubmission(
listing_id=listing.id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -1492,11 +1458,9 @@ async def _approve_sub_agent(
# Create new version if no matching version found
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
await prisma.models.StoreListingVersion.prisma(tx).create(
data={
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
"version": next_version,
"storeListingId": listing.id,
}
data=_create_sub_agent_version_data(
sub_graph, heading, main_agent_name, next_version, listing.id
)
)
await prisma.models.StoreListing.prisma(tx).update(
where={"id": listing.id}, data={"hasApprovedVersion": True}
@@ -1504,10 +1468,14 @@ async def _approve_sub_agent(
def _create_sub_agent_version_data(
sub_graph: prisma.models.AgentGraph, heading: str, main_agent_name: str
sub_graph: prisma.models.AgentGraph,
heading: str,
main_agent_name: str,
version: typing.Optional[int] = None,
store_listing_id: typing.Optional[str] = None,
) -> prisma.types.StoreListingVersionCreateInput:
"""Create store listing version data for a sub-agent"""
return prisma.types.StoreListingVersionCreateInput(
data = prisma.types.StoreListingVersionCreateInput(
agentGraphId=sub_graph.id,
agentGraphVersion=sub_graph.version,
name=sub_graph.name or heading,
@@ -1522,6 +1490,11 @@ def _create_sub_agent_version_data(
imageUrls=[], # Sub-agents don't need images
categories=[], # Sub-agents don't need categories
)
if version is not None:
data["version"] = version
if store_listing_id is not None:
data["storeListingId"] = store_listing_id
return data
async def review_store_submission(
@@ -1744,12 +1717,15 @@ async def review_store_submission(
# Convert to Pydantic model for consistency
return store_model.StoreSubmission(
listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
agent_id=submission.agentGraphId,
agent_version=submission.agentGraphVersion,
name=submission.name,
sub_heading=submission.subHeading,
slug=(submission.StoreListing.slug if submission.StoreListing else ""),
slug=(
submission.StoreListing.slug
if hasattr(submission, "storeListing") and submission.StoreListing
else ""
),
description=submission.description,
instructions=submission.instructions,
image_urls=submission.imageUrls or [],
@@ -1851,7 +1827,9 @@ async def get_admin_listings_with_versions(
where = prisma.types.StoreListingWhereInput(**where_dict)
include = prisma.types.StoreListingInclude(
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
order_by={"version": "desc"}
order_by=prisma.types._StoreListingVersion_version_OrderByInput(
version="desc"
)
),
OwningUser=True,
)
@@ -1876,7 +1854,6 @@ async def get_admin_listings_with_versions(
# If we have versions, turn them into StoreSubmission models
for version in listing.Versions or []:
version_model = store_model.StoreSubmission(
listing_id=listing.id,
agent_id=version.agentGraphId,
agent_version=version.agentGraphVersion,
name=version.name,

View File

@@ -110,7 +110,6 @@ class Profile(pydantic.BaseModel):
class StoreSubmission(pydantic.BaseModel):
listing_id: str
agent_id: str
agent_version: int
name: str
@@ -165,12 +164,8 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel):
class StoreSubmissionRequest(pydantic.BaseModel):
agent_id: str = pydantic.Field(
..., min_length=1, description="Agent ID cannot be empty"
)
agent_version: int = pydantic.Field(
..., gt=0, description="Agent version must be greater than 0"
)
agent_id: str
agent_version: int
slug: str
name: str
sub_heading: str

View File

@@ -138,7 +138,6 @@ def test_creator_details():
def test_store_submission():
submission = store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
@@ -160,7 +159,6 @@ def test_store_submissions_response():
response = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",

View File

@@ -521,7 +521,6 @@ def test_get_submissions_success(
mocked_value = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="test-listing-id",
name="Test Agent",
description="Test agent description",
image_urls=["test.jpg"],

View File

@@ -6,9 +6,6 @@ import hashlib
import hmac
import logging
from enum import Enum
from typing import cast
from prisma.types import Serializable
from backend.sdk import (
BaseWebhooksManager,
@@ -87,9 +84,7 @@ class AirtableWebhookManager(BaseWebhooksManager):
# update webhook config
await update_webhook(
webhook.id,
config=cast(
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
),
config={"base_id": base_id, "cursor": response.cursor},
)
event_type = "notification"

View File

@@ -1,184 +0,0 @@
"""
Shared helpers for Human-In-The-Loop (HITL) review functionality.
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
"""
import logging
from typing import Any, Optional
from prisma.enums import ReviewStatus
from pydantic import BaseModel
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
class ReviewDecision(BaseModel):
"""Result of a review decision."""
should_proceed: bool
message: str
review_result: ReviewResult
class HITLReviewHelper:
"""Helper class for Human-In-The-Loop review operations."""
@staticmethod
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
"""Create or retrieve a human review from the database."""
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
@staticmethod
async def update_node_execution_status(**kwargs) -> None:
"""Update the execution status of a node."""
await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
@staticmethod
async def update_review_processed_status(
node_exec_id: str, processed: bool
) -> None:
"""Update the processed status of a review."""
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
@staticmethod
async def _handle_review_request(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewResult if review is complete, None if waiting for human input
Raises:
Exception: If review creation or status update fails
"""
# Skip review if safe mode is disabled - return auto-approved result
if not execution_context.safe_mode:
logger.info(
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
)
return ReviewResult(
data=input_data,
status=ReviewStatus.APPROVED,
message="Auto-approved (safe mode disabled)",
processed=True,
node_exec_id=node_exec_id,
)
result = await HITLReviewHelper.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data,
message=f"Review required for {block_name} execution",
editable=editable,
)
if result is None:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
if not result.processed:
await HITLReviewHelper.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
return result
@staticmethod
async def handle_review_decision(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewDecision if review is complete (approved/rejected),
None if execution should pause (awaiting review)
"""
review_result = await HITLReviewHelper._handle_review_request(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=block_name,
editable=editable,
)
if review_result is None:
# Still awaiting review - return None to pause execution
return None
# Review is complete, determine outcome
should_proceed = review_result.status == ReviewStatus.APPROVED
message = review_result.message or (
"Execution approved by reviewer"
if should_proceed
else "Execution rejected by reviewer"
)
return ReviewDecision(
should_proceed=should_proceed, message=message, review_result=review_result
)

View File

@@ -3,7 +3,6 @@ from typing import Any
from prisma.enums import ReviewStatus
from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.block import (
Block,
BlockCategory,
@@ -12,9 +11,11 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
@@ -71,26 +72,32 @@ class HumanInTheLoopBlock(Block):
("approved_data", {"name": "John Doe", "age": 30}),
],
test_mock={
"handle_review_decision": lambda **kwargs: type(
"ReviewDecision",
(),
{
"should_proceed": True,
"message": "Test approval message",
"review_result": ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
},
)(),
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
"update_node_execution_status": lambda *_args, **_kwargs: None,
"update_review_processed_status": lambda *_args, **_kwargs: None,
},
)
async def handle_review_decision(self, **kwargs):
return await HITLReviewHelper.handle_review_decision(**kwargs)
async def get_or_create_human_review(self, **kwargs):
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
async def update_node_execution_status(self, **kwargs):
return await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
async def update_review_processed_status(self, node_exec_id: str, processed: bool):
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
async def run(
self,
@@ -102,7 +109,7 @@ class HumanInTheLoopBlock(Block):
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
**_kwargs,
**kwargs,
) -> BlockOutput:
if not execution_context.safe_mode:
logger.info(
@@ -112,28 +119,48 @@ class HumanInTheLoopBlock(Block):
yield "review_message", "Auto-approved (safe mode disabled)"
return
decision = await self.handle_review_decision(
input_data=input_data.data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=input_data.editable,
)
try:
result = await self.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data.data,
message=input_data.name,
editable=input_data.editable,
)
except Exception as e:
logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
raise
if decision is None:
return
if result is None:
logger.info(
f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
)
try:
await self.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return
except Exception as e:
logger.error(
f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
)
raise
status = decision.review_result.status
if status == ReviewStatus.APPROVED:
yield "approved_data", decision.review_result.data
elif status == ReviewStatus.REJECTED:
yield "rejected_data", decision.review_result.data
else:
raise RuntimeError(f"Unexpected review status: {status}")
if not result.processed:
await self.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
if decision.message:
yield "review_message", decision.message
if result.status == ReviewStatus.APPROVED:
yield "approved_data", result.data
if result.message:
yield "review_message", result.message
elif result.status == ReviewStatus.REJECTED:
yield "rejected_data", result.data
if result.message:
yield "review_message", result.message

File diff suppressed because it is too large Load Diff

View File

@@ -391,12 +391,8 @@ class SmartDecisionMakerBlock(Block):
"""
block = sink_node.block
# Use custom name from node metadata if set, otherwise fall back to block.name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else block.name
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"name": SmartDecisionMakerBlock.cleanup(block.name),
"description": block.description,
}
sink_block_input_schema = block.input_schema
@@ -493,24 +489,14 @@ class SmartDecisionMakerBlock(Block):
f"Sink graph metadata not found: {graph_id} {graph_version}"
)
# Use custom name from node metadata if set, otherwise fall back to graph name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else sink_graph_meta.name
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
"description": sink_graph_meta.description,
}
properties = {}
field_mapping = {}
for link in links:
field_name = link.sink_name
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
field_mapping[clean_field_name] = field_name
sink_block_input_schema = sink_node.input_default["input_schema"]
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
link.sink_name, {}
@@ -520,7 +506,7 @@ class SmartDecisionMakerBlock(Block):
if "description" in sink_block_properties
else f"The {link.sink_name} of the tool"
)
properties[clean_field_name] = {
properties[link.sink_name] = {
"type": "string",
"description": description,
"default": json.dumps(sink_block_properties.get("default", None)),
@@ -533,7 +519,7 @@ class SmartDecisionMakerBlock(Block):
"strict": True,
}
tool_function["_field_mapping"] = field_mapping
# Store node info for later use in output processing
tool_function["_sink_node_id"] = sink_node.id
return {"type": "function", "function": tool_function}
@@ -989,28 +975,10 @@ class SmartDecisionMakerBlock(Block):
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
nodes_to_skip: set[str] | None = None,
**kwargs,
) -> BlockOutput:
tool_functions = await self._create_tool_node_signatures(node_id)
original_tool_count = len(tool_functions)
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
if nodes_to_skip:
tool_functions = [
tf
for tf in tool_functions
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
]
# Only raise error if we had tools but they were all filtered out
if original_tool_count > 0 and not tool_functions:
raise ValueError(
"No available tools to execute - all downstream nodes are unavailable "
"(possibly due to missing optional credentials)"
)
yield "tool_functions", json.dumps(tool_functions)
conversation_history = input_data.conversation_history or []
@@ -1161,9 +1129,8 @@ class SmartDecisionMakerBlock(Block):
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
arg_value = tool_args.get(clean_arg_name)
# Use original_field_name directly (not sanitized) to match link sink_name
# The field_mapping already translates from LLM's cleaned names to original names
emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"
sanitized_arg_name = self.cleanup(original_field_name)
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
logger.debug(
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",

View File

@@ -1057,153 +1057,3 @@ async def test_smart_decision_maker_traditional_mode_default():
) # Should yield individual tool parameters
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert "conversations" in outputs
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_blocks():
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
mock_node = MagicMock(spec=Node)
mock_node.id = "test-node-id"
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the customized name (cleaned up)
assert result["type"] == "function"
assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version
assert result["function"]["_sink_node_id"] == "test-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_falls_back_to_block_name():
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
mock_node = MagicMock(spec=Node)
mock_node.id = "test-node-id"
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the block's default name
assert result["type"] == "function"
assert result["function"]["name"] == "storevalueblock" # Default block name cleaned
assert result["function"]["_sink_node_id"] == "test-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_agents():
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
mock_node = MagicMock(spec=Node)
mock_node.id = "test-agent-node-id"
mock_node.metadata = {"customized_name": "My Custom Agent"}
mock_node.input_default = {
"graph_id": "test-graph-id",
"graph_version": 1,
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "test_input"
# Mock the database client
mock_graph_meta = MagicMock()
mock_graph_meta.name = "Original Agent Name"
mock_graph_meta.description = "Agent description"
mock_db_client = AsyncMock()
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the customized name (cleaned up)
assert result["type"] == "function"
assert result["function"]["name"] == "my_custom_agent" # Cleaned version
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
"""Test that agent node falls back to graph name when no customized_name."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
mock_node = MagicMock(spec=Node)
mock_node.id = "test-agent-node-id"
mock_node.metadata = {} # No customized_name
mock_node.input_default = {
"graph_id": "test-graph-id",
"graph_version": 1,
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "test_input"
# Mock the database client
mock_graph_meta = MagicMock()
mock_graph_meta.name = "Original Agent Name"
mock_graph_meta.description = "Agent description"
mock_db_client = AsyncMock()
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the graph's default name
assert result["type"] == "function"
assert result["function"]["name"] == "original_agent_name" # Graph name cleaned
assert result["function"]["_sink_node_id"] == "test-agent-node-id"

View File

@@ -15,7 +15,6 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
mock_node.block = CreateDictionaryBlock()
mock_node.block_id = CreateDictionaryBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic dictionary fields
mock_links = [
@@ -78,7 +77,6 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
mock_node.block = AddToListBlock()
mock_node.block_id = AddToListBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic list fields
mock_links = [

View File

@@ -44,7 +44,6 @@ async def test_create_block_function_signature_with_dict_fields():
mock_node.block = CreateDictionaryBlock()
mock_node.block_id = CreateDictionaryBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
mock_links = [
@@ -107,7 +106,6 @@ async def test_create_block_function_signature_with_list_fields():
mock_node.block = AddToListBlock()
mock_node.block_id = AddToListBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic list fields
mock_links = [
@@ -161,7 +159,6 @@ async def test_create_block_function_signature_with_object_fields():
mock_node.block = MatchTextPatternBlock()
mock_node.block_id = MatchTextPatternBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic object fields
mock_links = [
@@ -211,13 +208,11 @@ async def test_create_tool_node_signatures():
mock_dict_node.block = CreateDictionaryBlock()
mock_dict_node.block_id = CreateDictionaryBlock().id
mock_dict_node.input_default = {}
mock_dict_node.metadata = {}
mock_list_node = Mock()
mock_list_node.block = AddToListBlock()
mock_list_node.block_id = AddToListBlock().id
mock_list_node.input_default = {}
mock_list_node.metadata = {}
# Mock links with dynamic fields
dict_link1 = Mock(
@@ -428,7 +423,6 @@ async def test_mixed_regular_and_dynamic_fields():
mock_node.block.name = "TestBlock"
mock_node.block.description = "A test block"
mock_node.block.input_schema = Mock()
mock_node.metadata = {}
# Mock the get_field_schema to return a proper schema for regular fields
def get_field_schema(field_name):

View File

@@ -42,6 +42,7 @@ from urllib.parse import urlparse
import click
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission
from prisma.types import OAuthApplicationCreateInput
keysmith = APIKeySmith()
@@ -147,7 +148,7 @@ def format_sql_insert(creds: dict) -> str:
sql = f"""
-- ============================================================
-- OAuth Application: {creds['name']}
-- OAuth Application: {creds["name"]}
-- Generated: {now_iso} UTC
-- ============================================================
@@ -167,14 +168,14 @@ INSERT INTO "OAuthApplication" (
"isActive"
)
VALUES (
'{creds['id']}',
'{creds["id"]}',
NOW(),
NOW(),
'{creds['name']}',
{f"'{creds['description']}'" if creds['description'] else 'NULL'},
'{creds['client_id']}',
'{creds['client_secret_hash']}',
'{creds['client_secret_salt']}',
'{creds["name"]}',
{f"'{creds['description']}'" if creds["description"] else "NULL"},
'{creds["client_id"]}',
'{creds["client_secret_hash"]}',
'{creds["client_secret_salt"]}',
ARRAY{redirect_uris_pg}::TEXT[],
ARRAY{grant_types_pg}::TEXT[],
ARRAY{scopes_pg}::"APIKeyPermission"[],
@@ -186,8 +187,8 @@ VALUES (
-- ⚠️ IMPORTANT: Save these credentials securely!
-- ============================================================
--
-- Client ID: {creds['client_id']}
-- Client Secret: {creds['client_secret_plaintext']}
-- Client ID: {creds["client_id"]}
-- Client Secret: {creds["client_secret_plaintext"]}
--
-- ⚠️ The client secret is shown ONLY ONCE!
-- ⚠️ Store it securely and share only with the application developer.
@@ -200,7 +201,7 @@ VALUES (
-- To verify the application was created:
-- SELECT "clientId", name, scopes, "redirectUris", "isActive"
-- FROM "OAuthApplication"
-- WHERE "clientId" = '{creds['client_id']}';
-- WHERE "clientId" = '{creds["client_id"]}';
"""
return sql
@@ -834,19 +835,19 @@ async def create_test_app_in_db(
# Insert into database
app = await OAuthApplication.prisma().create(
data={
"id": creds["id"],
"name": creds["name"],
"description": creds["description"],
"clientId": creds["client_id"],
"clientSecret": creds["client_secret_hash"],
"clientSecretSalt": creds["client_secret_salt"],
"redirectUris": creds["redirect_uris"],
"grantTypes": creds["grant_types"],
"scopes": creds["scopes"],
"ownerId": owner_id,
"isActive": True,
}
data=OAuthApplicationCreateInput(
id=creds["id"],
name=creds["name"],
description=creds["description"],
clientId=creds["client_id"],
clientSecret=creds["client_secret_hash"],
clientSecretSalt=creds["client_secret_salt"],
redirectUris=creds["redirect_uris"],
grantTypes=creds["grant_types"],
scopes=creds["scopes"],
ownerId=owner_id,
isActive=True,
)
)
click.echo(f"✓ Created test OAuth application: {app.clientId}")

View File

@@ -6,7 +6,7 @@ from typing import Literal, Optional
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
from pydantic import Field
from backend.data.includes import MAX_USER_API_KEYS_FETCH
@@ -82,17 +82,17 @@ async def create_api_key(
generated_key = keysmith.generate_key()
saved_key_obj = await PrismaAPIKey.prisma().create(
data={
"id": str(uuid.uuid4()),
"name": name,
"head": generated_key.head,
"tail": generated_key.tail,
"hash": generated_key.hash,
"salt": generated_key.salt,
"permissions": [p for p in permissions],
"description": description,
"userId": user_id,
}
data=APIKeyCreateInput(
id=str(uuid.uuid4()),
name=name,
head=generated_key.head,
tail=generated_key.tail,
hash=generated_key.hash,
salt=generated_key.salt,
permissions=[p for p in permissions],
description=description,
userId=user_id,
)
)
return APIKeyInfo.from_db(saved_key_obj), generated_key.key

View File

@@ -22,7 +22,12 @@ from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.types import OAuthApplicationUpdateInput
from prisma.types import (
OAuthAccessTokenCreateInput,
OAuthApplicationUpdateInput,
OAuthAuthorizationCodeCreateInput,
OAuthRefreshTokenCreateInput,
)
from pydantic import BaseModel, Field, SecretStr
from .base import APIAuthorizationInfo
@@ -359,17 +364,17 @@ async def create_authorization_code(
expires_at = now + AUTHORIZATION_CODE_TTL
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
data={
"id": str(uuid.uuid4()),
"code": code,
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
"redirectUri": redirect_uri,
"codeChallenge": code_challenge,
"codeChallengeMethod": code_challenge_method,
}
data=OAuthAuthorizationCodeCreateInput(
id=str(uuid.uuid4()),
code=code,
expiresAt=expires_at,
applicationId=application_id,
userId=user_id,
scopes=[s for s in scopes],
redirectUri=redirect_uri,
codeChallenge=code_challenge,
codeChallengeMethod=code_challenge_method,
)
)
return OAuthAuthorizationCodeInfo.from_db(saved_code)
@@ -490,14 +495,14 @@ async def create_access_token(
expires_at = now + ACCESS_TOKEN_TTL
saved_token = await PrismaOAuthAccessToken.prisma().create(
data={
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
}
data=OAuthAccessTokenCreateInput(
id=str(uuid.uuid4()),
token=token_hash, # SHA256 hash for direct lookup
expiresAt=expires_at,
applicationId=application_id,
userId=user_id,
scopes=[s for s in scopes],
)
)
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
@@ -607,14 +612,14 @@ async def create_refresh_token(
expires_at = now + REFRESH_TOKEN_TTL
saved_token = await PrismaOAuthRefreshToken.prisma().create(
data={
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
}
data=OAuthRefreshTokenCreateInput(
id=str(uuid.uuid4()),
token=token_hash, # SHA256 hash for direct lookup
expiresAt=expires_at,
applicationId=application_id,
userId=user_id,
scopes=[s for s in scopes],
)
)
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)

View File

@@ -50,8 +50,6 @@ from .model import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from .graph import Link
app_config = Config()
@@ -474,7 +472,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.requires_human_review: bool = False
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -617,77 +614,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
) from ex
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
# Skip review if not required or safe mode is disabled
if not self.requires_human_review or not execution_context.safe_mode:
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement and get potentially modified input data
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
@@ -695,7 +622,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,

View File

@@ -11,6 +11,7 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for ceiling tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
)

View File

@@ -14,6 +14,7 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError
@@ -28,11 +29,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user with initial balance."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -41,7 +42,10 @@ async def create_test_user(user_id: str) -> None:
# Ensure UserBalance record exists
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
)
@@ -342,10 +346,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
# First, set balance near max
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": max_int - 100},
"update": {"balance": max_int - 100},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=max_int - 100),
update={"balance": max_int - 100},
),
)
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
@@ -507,7 +511,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
print("\nExecution order by start time:")
for i, (label, timing) in enumerate(sorted_timings):
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
print(f" {i + 1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
# Check for overlap (true concurrency) vs serialization
overlaps = []
@@ -546,7 +550,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
print("\nDatabase transaction order (by createdAt):")
for i, tx in enumerate(transactions):
print(
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
f" {i + 1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
)
# Verify running balances are chronologically consistent (ordered by createdAt)
@@ -707,7 +711,7 @@ async def test_prove_database_locking_behavior(server: SpinTestServer):
for i, result in enumerate(sorted_results):
print(
f" {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
f" {i + 1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
)
# Check if any operations overlapped at the database level

View File

@@ -8,6 +8,7 @@ which would have caught the CreditTransactionType enum casting bug.
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserCreateInput
from backend.data.credit import (
AutoTopUpConfig,
@@ -29,12 +30,12 @@ async def cleanup_test_user():
# Create the user first
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"topUpConfig": SafeJson({}),
"timezone": "UTC",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
topUpConfig=SafeJson({}),
timezone="UTC",
)
)
except Exception:
# User might already exist, that's fine

View File

@@ -12,6 +12,12 @@ import pytest
import stripe
from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserCreateInput,
)
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
@@ -35,32 +41,32 @@ async def setup_test_user_with_topup():
# Create user
await User.prisma().create(
data={
"id": REFUND_TEST_USER_ID,
"email": f"{REFUND_TEST_USER_ID}@example.com",
"name": "Refund Test User",
}
data=UserCreateInput(
id=REFUND_TEST_USER_ID,
email=f"{REFUND_TEST_USER_ID}@example.com",
name="Refund Test User",
)
)
# Create user balance
await UserBalance.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"balance": 1000, # $10
}
data=UserBalanceCreateInput(
userId=REFUND_TEST_USER_ID,
balance=1000, # $10
)
)
# Create a top-up transaction that can be refunded
topup_tx = await CreditTransaction.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 1000,
"type": CreditTransactionType.TOP_UP,
"transactionKey": "pi_test_12345",
"runningBalance": 1000,
"isActive": True,
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
}
data=CreditTransactionCreateInput(
userId=REFUND_TEST_USER_ID,
amount=1000,
type=CreditTransactionType.TOP_UP,
transactionKey="pi_test_12345",
runningBalance=1000,
isActive=True,
metadata=SafeJson({"stripe_payment_intent": "pi_test_12345"}),
)
)
return topup_tx
@@ -93,12 +99,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
# Create refund request record (simulating webhook flow)
await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 500,
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
"reason": "Test refund",
}
data=CreditRefundRequestCreateInput(
userId=REFUND_TEST_USER_ID,
amount=500,
transactionKey=topup_tx.transactionKey, # Should match the original transaction
reason="Test refund",
)
)
# Call deduct_credits
@@ -286,12 +292,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
refund_requests = []
for i in range(5):
req = await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 100, # $1 each
"transactionKey": topup_tx.transactionKey,
"reason": f"Test refund {i}",
}
data=CreditRefundRequestCreateInput(
userId=REFUND_TEST_USER_ID,
amount=100, # $1 each
transactionKey=topup_tx.transactionKey,
reason=f"Test refund {i}",
)
)
refund_requests.append(req)

View File

@@ -3,6 +3,11 @@ from datetime import datetime, timedelta, timezone
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance
from prisma.types import (
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserBalanceUpsertInput,
)
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
@@ -23,10 +28,10 @@ async def disable_test_user_transactions():
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
"update": {"balance": 0, "updatedAt": old_date},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=0),
update={"balance": 0, "updatedAt": old_date},
),
)
@@ -140,23 +145,23 @@ async def test_block_credit_reset(server: SpinTestServer):
# Manually create a transaction with month 1 timestamp to establish history
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": 100,
"type": CreditTransactionType.TOP_UP,
"runningBalance": 1100,
"isActive": True,
"createdAt": month1, # Set specific timestamp
}
data=CreditTransactionCreateInput(
userId=DEFAULT_USER_ID,
amount=100,
type=CreditTransactionType.TOP_UP,
runningBalance=1100,
isActive=True,
createdAt=month1, # Set specific timestamp
)
)
# Update user balance to match
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
"update": {"balance": 1100},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=1100),
update={"balance": 1100},
),
)
# Now test month 2 behavior
@@ -175,14 +180,14 @@ async def test_block_credit_reset(server: SpinTestServer):
# Create a month 2 transaction to update the last transaction time
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": -700, # Spent 700 to get to 400
"type": CreditTransactionType.USAGE,
"runningBalance": 400,
"isActive": True,
"createdAt": month2,
}
data=CreditTransactionCreateInput(
userId=DEFAULT_USER_ID,
amount=-700, # Spent 700 to get to 400
type=CreditTransactionType.USAGE,
runningBalance=400,
isActive=True,
createdAt=month2,
)
)
# Move to month 3

View File

@@ -12,6 +12,7 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for underflow tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
)
@@ -66,14 +70,14 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
initial_balance_target = POSTGRES_INT_MIN + 100
# Use direct database update to set the balance close to underflow
from prisma.models import UserBalance
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance_target},
"update": {"balance": initial_balance_target},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(
userId=user_id, balance=initial_balance_target
),
update={"balance": initial_balance_target},
),
)
current_balance = await credit_system.get_credits(user_id)
@@ -110,10 +114,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
# Set balance to exactly POSTGRES_INT_MIN
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
"update": {"balance": POSTGRES_INT_MIN},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=POSTGRES_INT_MIN),
update={"balance": POSTGRES_INT_MIN},
),
)
edge_balance = await credit_system.get_credits(user_id)
@@ -147,15 +151,13 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
# Set up balance close to underflow threshold to test the protection
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
# This should trigger underflow protection
from prisma.models import UserBalance
test_balance = POSTGRES_INT_MIN + 1000
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": test_balance},
"update": {"balance": test_balance},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=test_balance),
update={"balance": test_balance},
),
)
current_balance = await credit_system.get_credits(user_id)
@@ -212,15 +214,13 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
update={"balance": initial_balance},
),
)
# Apply multiple refunds that would cumulatively underflow
@@ -295,10 +295,10 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
update={"balance": initial_balance},
),
)
async def large_refund(amount: int, label: str):

View File

@@ -14,6 +14,7 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserCreateInput
from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
@@ -24,11 +25,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for migration tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -121,7 +122,7 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
try:
# Create UserBalance with specific value
await UserBalance.prisma().create(
data={"userId": user_id, "balance": 5000} # $50
data=UserBalanceCreateInput(userId=user_id, balance=5000) # $50
)
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
@@ -160,7 +161,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
try:
# Set initial balance in UserBalance
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
await UserBalance.prisma().create(
data=UserBalanceCreateInput(userId=user_id, balance=1000)
)
# Run concurrent operations to ensure they all use UserBalance atomic operations
async def concurrent_spend(amount: int, label: str):

View File

@@ -28,6 +28,7 @@ from prisma.models import (
AgentNodeExecutionKeyValueData,
)
from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionUpdateManyMutationInput,
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
@@ -35,7 +36,6 @@ from prisma.types import (
AgentNodeExecutionKeyValueDataCreateInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
AgentNodeExecutionWhereUniqueInput,
)
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field
@@ -383,7 +383,6 @@ class GraphExecutionWithNodes(GraphExecution):
self,
execution_context: ExecutionContext,
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
):
return GraphExecutionEntry(
user_id=self.user_id,
@@ -391,7 +390,6 @@ class GraphExecutionWithNodes(GraphExecution):
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip or set(),
execution_context=execution_context,
)
@@ -711,18 +709,18 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
"""
result = await AgentGraphExecution.prisma().create(
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.INCOMPLETE,
"inputs": SafeJson(inputs),
"credentialInputs": (
data=AgentGraphExecutionCreateInput(
agentGraphId=graph_id,
agentGraphVersion=graph_version,
executionStatus=ExecutionStatus.INCOMPLETE,
inputs=SafeJson(inputs),
credentialInputs=(
SafeJson(credential_inputs) if credential_inputs else Json({})
),
"nodesInputMasks": (
nodesInputMasks=(
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
),
"NodeExecutions": {
NodeExecutions={
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
@@ -738,10 +736,10 @@ async def create_graph_execution(
for node_id, node_input in starting_nodes_input
]
},
"userId": user_id,
"agentPresetId": preset_id,
"parentGraphExecutionId": parent_graph_exec_id,
},
userId=user_id,
agentPresetId=preset_id,
parentGraphExecutionId=parent_graph_exec_id,
),
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
@@ -833,10 +831,10 @@ async def upsert_execution_output(
"""
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
"""
data: AgentNodeExecutionInputOutputCreateInput = {
"name": output_name,
"referencedByOutputExecId": node_exec_id,
}
data = AgentNodeExecutionInputOutputCreateInput(
name=output_name,
referencedByOutputExecId=node_exec_id,
)
if output_data is not None:
data["data"] = SafeJson(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data)
@@ -966,6 +964,12 @@ async def update_node_execution_status(
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
"""
Update a node execution's status with validation of allowed transitions.
⚠️ Internal executor use only - no user_id check. Callers (executor/manager.py)
are responsible for validating user authorization before invoking this function.
"""
if status == ExecutionStatus.QUEUED and execution_data is None:
raise ValueError("Execution data must be provided when queuing an execution.")
@@ -976,25 +980,27 @@ async def update_node_execution_status(
f"Invalid status transition: {status} has no valid source statuses"
)
if res := await AgentNodeExecution.prisma().update(
where=cast(
AgentNodeExecutionWhereUniqueInput,
{
"id": node_exec_id,
"executionStatus": {"in": [s.value for s in allowed_from]},
},
),
# Fetch current execution to validate status transition
current = await AgentNodeExecution.prisma().find_unique(
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
)
if not current:
raise ValueError(f"Execution {node_exec_id} not found.")
# Validate current status allows transition to the new status
if current.executionStatus not in allowed_from:
# Return current state without updating if transition is not allowed
return NodeExecutionResult.from_db(current)
# Perform the update with only the unique identifier
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE,
):
return NodeExecutionResult.from_db(res)
if res := await AgentNodeExecution.prisma().find_unique(
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
):
return NodeExecutionResult.from_db(res)
raise ValueError(f"Execution {node_exec_id} not found.")
)
if not res:
raise ValueError(f"Failed to update execution {node_exec_id}.")
return NodeExecutionResult.from_db(res)
def _get_update_status_data(
@@ -1147,8 +1153,6 @@ class GraphExecutionEntry(BaseModel):
graph_id: str
graph_version: int
nodes_input_masks: Optional[NodesInputMasks] = None
nodes_to_skip: set[str] = Field(default_factory=set)
"""Node IDs that should be skipped due to optional credentials not being configured."""
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)

View File

@@ -94,15 +94,6 @@ class Node(BaseDbModel):
input_links: list[Link] = []
output_links: list[Link] = []
@property
def credentials_optional(self) -> bool:
"""
Whether credentials are optional for this node.
When True and credentials are not configured, the node will be skipped
during execution rather than causing a validation error.
"""
return self.metadata.get("credentials_optional", False)
@property
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
@@ -244,10 +235,7 @@ class BaseGraph(BaseDbModel):
return any(
node.block_id
for node in self.nodes
if (
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
or node.block.requires_human_review
)
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
)
@property
@@ -338,35 +326,7 @@ class Graph(BaseGraph):
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
schema = self._credentials_input_schema.jsonschema()
# Determine which credential fields are required based on credentials_optional metadata
graph_credentials_inputs = self.aggregate_credentials_inputs()
required_fields = []
# Build a map of node_id -> node for quick lookup
all_nodes = {node.id: node for node in self.nodes}
for sub_graph in self.sub_graphs:
for node in sub_graph.nodes:
all_nodes[node.id] = node
for field_key, (
_field_info,
node_field_pairs,
) in graph_credentials_inputs.items():
# A field is required if ANY node using it has credentials_optional=False
is_required = False
for node_id, _field_name in node_field_pairs:
node = all_nodes.get(node_id)
if node and not node.credentials_optional:
is_required = True
break
if is_required:
required_fields.append(field_key)
schema["required"] = required_fields
return schema
return self._credentials_input_schema.jsonschema()
@property
def _credentials_input_schema(self) -> type[BlockSchema]:

View File

@@ -396,58 +396,3 @@ async def test_access_store_listing_graph(server: SpinTestServer):
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
)
assert got_graph is not None
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
def test_node_credentials_optional_default():
"""Test that credentials_optional defaults to False when not set in metadata."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={},
)
assert node.credentials_optional is False
def test_node_credentials_optional_true():
"""Test that credentials_optional returns True when explicitly set."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": True},
)
assert node.credentials_optional is True
def test_node_credentials_optional_false():
"""Test that credentials_optional returns False when explicitly set to False."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": False},
)
assert node.credentials_optional is False
def test_node_credentials_optional_with_other_metadata():
"""Test that credentials_optional works correctly with other metadata present."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={
"position": {"x": 100, "y": 200},
"customized_name": "My Custom Node",
"credentials_optional": True,
},
)
assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node"

View File

@@ -10,7 +10,11 @@ from typing import Optional
from prisma.enums import ReviewStatus
from prisma.models import PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput
from prisma.types import (
PendingHumanReviewCreateInput,
PendingHumanReviewUpdateInput,
PendingHumanReviewUpsertInput,
)
from pydantic import BaseModel
from backend.api.features.executions.review.model import (
@@ -66,20 +70,20 @@ async def get_or_create_human_review(
# Upsert - get existing or create new review
review = await PendingHumanReview.prisma().upsert(
where={"nodeExecId": node_exec_id},
data={
"create": {
"userId": user_id,
"nodeExecId": node_exec_id,
"graphExecId": graph_exec_id,
"graphId": graph_id,
"graphVersion": graph_version,
"payload": SafeJson(input_data),
"instructions": message,
"editable": editable,
"status": ReviewStatus.WAITING,
},
"update": {}, # Do nothing on update - keep existing review as is
},
data=PendingHumanReviewUpsertInput(
create=PendingHumanReviewCreateInput(
userId=user_id,
nodeExecId=node_exec_id,
graphExecId=graph_exec_id,
graphId=graph_id,
graphVersion=graph_version,
payload=SafeJson(input_data),
instructions=message,
editable=editable,
status=ReviewStatus.WAITING,
),
update={}, # Do nothing on update - keep existing review as is
),
)
logger.info(

View File

@@ -7,7 +7,11 @@ import prisma
import pydantic
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from prisma.types import (
UserOnboardingCreateInput,
UserOnboardingUpdateInput,
UserOnboardingUpsertInput,
)
from backend.api.features.store.model import StoreAgentDetails
from backend.api.model import OnboardingNotificationPayload
@@ -92,6 +96,7 @@ async def reset_user_onboarding(user_id: str):
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
# get_user_onboarding guarantees the record exists via upsert
onboarding = await get_user_onboarding(user_id)
if data.walletShown:
update["walletShown"] = data.walletShown
@@ -110,12 +115,14 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
# The create branch is never taken since get_user_onboarding ensures the record exists,
# but upsert requires a create payload so we provide a minimal one
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update},
"update": update,
},
data=UserOnboardingUpsertInput(
create=UserOnboardingCreateInput(userId=user_id),
update=update,
),
)

View File

@@ -178,7 +178,6 @@ async def execute_node(
execution_processor: "ExecutionProcessor",
execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> BlockOutput:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -246,7 +245,6 @@ async def execute_node(
"user_id": user_id,
"execution_context": execution_context,
"execution_processor": execution_processor,
"nodes_to_skip": nodes_to_skip or set(),
}
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -544,7 +542,6 @@ class ExecutionProcessor:
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[NodesInputMasks],
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
nodes_to_skip: Optional[set[str]] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
logger=_logger,
@@ -567,7 +564,6 @@ class ExecutionProcessor:
db_client=db_client,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
)
if isinstance(status, BaseException):
raise status
@@ -613,7 +609,6 @@ class ExecutionProcessor:
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> ExecutionStatus:
status = ExecutionStatus.RUNNING
@@ -650,7 +645,6 @@ class ExecutionProcessor:
execution_processor=self,
execution_stats=stats,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
):
await persist_output(output_name, output_data)
@@ -962,21 +956,6 @@ class ExecutionProcessor:
queued_node_exec = execution_queue.get()
# Check if this node should be skipped due to optional credentials
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
log_metadata.info(
f"Skipping node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id} - optional credentials not configured"
)
# Mark the node as completed without executing
# No outputs will be produced, so downstream nodes won't trigger
update_node_execution_status(
db_client=db_client,
exec_id=queued_node_exec.node_exec_id,
status=ExecutionStatus.COMPLETED,
)
continue
log_metadata.debug(
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
@@ -1037,7 +1016,6 @@ class ExecutionProcessor:
execution_stats,
execution_stats_lock,
),
nodes_to_skip=graph_exec.nodes_to_skip,
),
self.node_execution_loop,
)

View File

@@ -239,19 +239,14 @@ async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[dict[str, dict[str, str]], set[str]]:
) -> dict[str, dict[str, str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors
and a set of nodes that should be skipped due to optional missing credentials.
Checks all credentials for all nodes of the graph and returns structured errors.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
"""
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
nodes_to_skip: set[str] = set()
for node in graph.nodes:
block = node.block
@@ -261,46 +256,27 @@ async def _validate_node_input_credentials(
if not credentials_fields:
continue
# Track if any credential field is missing for this node
has_missing_credentials = False
for field_name, credentials_meta_type in credentials_fields.items():
try:
# Check nodes_input_masks first, then input_default
field_value = None
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
field_value = node_input_mask[field_name]
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]
# Check if credentials are missing (None, empty, or not present)
if field_value is None or (
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
credentials_meta = credentials_meta_type.model_validate(field_value)
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
# Missing credentials
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
except ValidationError as e:
# Validation error means credentials were provided but invalid
# This should always be an error, even if optional
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
continue
@@ -311,7 +287,6 @@ async def _validate_node_input_credentials(
)
except Exception as e:
# Handle any errors fetching credentials
# If credentials were explicitly configured but unavailable, it's an error
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
@@ -338,19 +313,7 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue
# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id} will be skipped: optional credentials not configured"
)
return credential_errors, nodes_to_skip
return credential_errors
def make_node_credentials_input_map(
@@ -392,25 +355,21 @@ async def validate_graph_with_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
) -> Mapping[str, Mapping[str, str]]:
"""
Validate graph including credentials and return structured errors per node,
along with a set of nodes that should be skipped due to optional missing credentials.
Validate graph including credentials and return structured errors per node.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
dict[node_id, dict[field_name, error_message]]: Validation errors per node
"""
# Get input validation errors
node_input_errors = GraphModel.validate_graph_get_errors(
graph, for_run=True, nodes_input_masks=nodes_input_masks
)
# Get credential input/availability/validation errors and nodes to skip
node_credential_input_errors, nodes_to_skip = (
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
# Get credential input/availability/validation errors
node_credential_input_errors = await _validate_node_input_credentials(
graph, user_id, nodes_input_masks
)
# Merge credential errors with structural errors
@@ -419,7 +378,7 @@ async def validate_graph_with_credentials(
node_input_errors[node_id] = {}
node_input_errors[node_id].update(field_errors)
return node_input_errors, nodes_to_skip
return node_input_errors
async def _construct_starting_node_execution_input(
@@ -427,7 +386,7 @@ async def _construct_starting_node_execution_input(
user_id: str,
graph_inputs: BlockInput,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
@@ -441,14 +400,11 @@ async def _construct_starting_node_execution_input(
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
Returns:
tuple[
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
and the corresponding input data for that node.
set[str]: Node IDs that should be skipped (optional credentials not configured)
]
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
# Use new validation function that includes credentials
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
validation_errors = await validate_graph_with_credentials(
graph, user_id, nodes_input_masks
)
n_error_nodes = len(validation_errors)
@@ -489,7 +445,7 @@ async def _construct_starting_node_execution_input(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
return nodes_input, nodes_to_skip
return nodes_input
async def validate_and_construct_node_execution_input(
@@ -500,7 +456,7 @@ async def validate_and_construct_node_execution_input(
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
is_sub_graph: bool = False,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
This centralizes the logic used by both scheduler validation and actual execution.
@@ -517,7 +473,6 @@ async def validate_and_construct_node_execution_input(
GraphModel: Full graph object for the given `graph_id`.
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
dict[str, BlockInput]: Node input masks including all passed-in credentials.
set[str]: Node IDs that should be skipped (optional credentials not configured).
Raises:
NotFoundError: If the graph is not found.
@@ -559,16 +514,14 @@ async def validate_and_construct_node_execution_input(
nodes_input_masks or {},
)
starting_nodes_input, nodes_to_skip = (
await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
starting_nodes_input = await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
return graph, starting_nodes_input, nodes_input_masks
def _merge_nodes_input_masks(
@@ -826,9 +779,6 @@ async def add_graph_execution(
# Use existing execution's compiled input masks
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
# For resumed executions, nodes_to_skip was already determined at creation time
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
nodes_to_skip: set[str] = set()
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
else:
@@ -837,7 +787,7 @@ async def add_graph_execution(
)
# Create new execution
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
graph, starting_nodes_input, compiled_nodes_input_masks = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
@@ -886,7 +836,6 @@ async def add_graph_execution(
try:
graph_exec_entry = graph_exec.to_graph_execution_entry(
compiled_nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Publishing execution {graph_exec.id} to execution queue")

View File

@@ -367,13 +367,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
)
# Setup mock returns
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
nodes_to_skip: set[str] = set()
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip,
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
@@ -459,212 +456,3 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
# Both executions should succeed (though they create different objects)
assert result1 == mock_graph_exec
assert result2 == mock_graph_exec_2
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
@pytest.mark.asyncio
async def test_validate_node_input_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns nodes_to_skip set
for nodes with credentials_optional=True and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=True
mock_node = mocker.MagicMock()
mock_node.id = "node-with-optional-creds"
mock_node.credentials_optional = True
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in nodes_to_skip, not in errors
assert mock_node.id in nodes_to_skip
assert mock_node.id not in errors
@pytest.mark.asyncio
async def test_validate_node_input_credentials_required_missing_creds_error(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns errors
for nodes with credentials_optional=False and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=False (required)
mock_node = mocker.MagicMock()
mock_node.id = "node-with-required-creds"
mock_node.credentials_optional = False
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in errors, not in nodes_to_skip
assert mock_node.id in errors
assert "credentials" in errors[mock_node.id]
assert "required" in errors[mock_node.id]["credentials"].lower()
assert mock_node.id not in nodes_to_skip
@pytest.mark.asyncio
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that validate_graph_with_credentials returns nodes_to_skip set
from _validate_node_input_credentials.
"""
from backend.executor.utils import validate_graph_with_credentials
# Mock _validate_node_input_credentials to return specific values
mock_validate = mocker.patch(
"backend.executor.utils._validate_node_input_credentials"
)
expected_errors = {"node1": {"field": "error"}}
expected_nodes_to_skip = {"node2", "node3"}
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
# Mock GraphModel with validate_graph_get_errors method
mock_graph = mocker.MagicMock()
mock_graph.validate_graph_get_errors.return_value = {}
# Call the function
errors, nodes_to_skip = await validate_graph_with_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Verify nodes_to_skip is passed through
assert nodes_to_skip == expected_nodes_to_skip
assert "node1" in errors
@pytest.mark.asyncio
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
"""
Test that add_graph_execution properly passes nodes_to_skip
to the graph execution entry.
"""
from backend.data.execution import GraphExecutionWithNodes
from backend.executor.utils import add_graph_execution
# Mock data
graph_id = "test-graph-id"
user_id = "test-user-id"
inputs = {"test_input": "test_value"}
graph_version = 1
# Mock the graph object
mock_graph = mocker.MagicMock()
mock_graph.version = graph_version
# Starting nodes and masks
starting_nodes_input = [("node1", {"input1": "value1"})]
compiled_nodes_input_masks = {}
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
# Mock the graph execution object
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []
# Track what's passed to to_graph_execution_entry
captured_kwargs = {}
def capture_to_entry(**kwargs):
captured_kwargs.update(kwargs)
return mocker.MagicMock()
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
# Setup mocks
mock_validate = mocker.patch(
"backend.executor.utils.validate_and_construct_node_execution_input"
)
mock_edb = mocker.patch("backend.executor.utils.execution_db")
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_udb = mocker.patch("backend.executor.utils.user_db")
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
# Setup returns - include nodes_to_skip in the tuple
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip, # This should be passed through
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
mock_user = mocker.MagicMock()
mock_user.timezone = "UTC"
mock_settings = mocker.MagicMock()
mock_settings.human_in_the_loop_safe_mode = True
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
mock_get_queue.return_value = mocker.AsyncMock()
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
# Call the function
await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
graph_version=graph_version,
)
# Verify nodes_to_skip was passed to to_graph_execution_entry
assert "nodes_to_skip" in captured_kwargs
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip

View File

@@ -8,7 +8,6 @@ from .discord import DiscordOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
from .reddit import RedditOAuthHandler
from .twitter import TwitterOAuthHandler
if TYPE_CHECKING:
@@ -21,7 +20,6 @@ _ORIGINAL_HANDLERS = [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
RedditOAuthHandler,
TwitterOAuthHandler,
TodoistOAuthHandler,
]

View File

@@ -1,208 +0,0 @@
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
from backend.util.settings import Settings
settings = Settings()
class RedditOAuthHandler(BaseOAuthHandler):
"""
Reddit OAuth 2.0 handler.
Based on the documentation at:
- https://github.com/reddit-archive/reddit/wiki/OAuth2
Notes:
- Reddit requires `duration=permanent` to get refresh tokens
- Access tokens expire after 1 hour (3600 seconds)
- Reddit requires HTTP Basic Auth for token requests
- Reddit requires a unique User-Agent header
"""
PROVIDER_NAME = ProviderName.REDDIT
DEFAULT_SCOPES: ClassVar[list[str]] = [
"identity", # Get username, verify auth
"read", # Access posts and comments
"submit", # Submit new posts and comments
"edit", # Edit own posts and comments
"history", # Access user's post history
"privatemessages", # Access inbox and send private messages
"flair", # Access and set flair on posts/subreddits
]
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
"""Generate Reddit OAuth 2.0 authorization URL"""
scopes = self.handle_default_scopes(scopes)
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(scopes),
"state": state,
"duration": "permanent", # Required for refresh tokens
}
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
"""Exchange authorization code for access tokens"""
scopes = self.handle_default_scopes(scopes)
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
}
# Reddit requires HTTP Basic Auth for token requests
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token exchange failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
username=username,
access_token=tokens["access_token"],
refresh_token=tokens.get("refresh_token"),
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
scopes=scopes,
)
async def _get_username(self, access_token: str) -> str:
"""Get the username from the access token"""
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": settings.config.reddit_user_agent,
}
response = await Requests().get(self.USERNAME_URL, headers=headers)
if not response.ok:
raise ValueError(f"Failed to get Reddit username: {response.status}")
data = response.json()
return data.get("name", "unknown")
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""Refresh access tokens using refresh token"""
if not credentials.refresh_token:
raise ValueError("No refresh token available")
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token refresh failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
# Reddit may or may not return a new refresh token
new_refresh_token = tokens.get("refresh_token")
if new_refresh_token:
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
elif credentials.refresh_token:
# Keep the existing refresh token
refresh_token = credentials.refresh_token
else:
refresh_token = None
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
username=username,
access_token=tokens["access_token"],
refresh_token=refresh_token,
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None,
scopes=credentials.scopes,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
"""Revoke the access token"""
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.REVOKE_URL, headers=headers, data=data, auth=auth
)
# Reddit returns 204 No Content on successful revocation
return response.ok

View File

@@ -264,7 +264,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
)
reddit_user_agent: str = Field(
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
default="AutoGPT:1.0 (by /u/autogpt)",
description="The user agent for the Reddit API",
)

View File

@@ -1,227 +0,0 @@
#!/usr/bin/env python3
"""
Generate a lightweight stub for prisma/types.py that collapses all exported
symbols to Any. This prevents Pyright from spending time/budget on Prisma's
query DSL types while keeping runtime behavior unchanged.
Usage:
poetry run gen-prisma-stub
This script automatically finds the prisma package location and generates
the types.pyi stub file in the same directory as types.py.
"""
from __future__ import annotations
import ast
import importlib.util
import sys
from pathlib import Path
from typing import Iterable, Set
def _iter_assigned_names(target: ast.expr) -> Iterable[str]:
"""Extract names from assignment targets (handles tuple unpacking)."""
if isinstance(target, ast.Name):
yield target.id
elif isinstance(target, (ast.Tuple, ast.List)):
for elt in target.elts:
yield from _iter_assigned_names(elt)
def _is_private(name: str) -> bool:
"""Check if a name is private (starts with _ but not __)."""
return name.startswith("_") and not name.startswith("__")
def _is_safe_type_alias(node: ast.Assign) -> bool:
"""Check if an assignment is a safe type alias that shouldn't be stubbed.
Safe types are:
- Literal types (don't cause type budget issues)
- Simple type references (SortMode, SortOrder, etc.)
- TypeVar definitions
"""
if not node.value:
return False
# Check if it's a Subscript (like Literal[...], Union[...], TypeVar[...])
if isinstance(node.value, ast.Subscript):
# Get the base type name
if isinstance(node.value.value, ast.Name):
base_name = node.value.value.id
# Literal types are safe
if base_name == "Literal":
return True
# TypeVar is safe
if base_name == "TypeVar":
return True
elif isinstance(node.value.value, ast.Attribute):
# Handle typing_extensions.Literal etc.
if node.value.value.attr == "Literal":
return True
# Check if it's a simple Name reference (like SortMode = _types.SortMode)
if isinstance(node.value, ast.Attribute):
return True
# Check if it's a Call (like TypeVar(...))
if isinstance(node.value, ast.Call):
if isinstance(node.value.func, ast.Name):
if node.value.func.id == "TypeVar":
return True
return False
def collect_top_level_symbols(
tree: ast.Module, source_lines: list[str]
) -> tuple[Set[str], Set[str], list[str], Set[str]]:
"""Collect all top-level symbols from an AST module.
Returns:
Tuple of (class_names, function_names, safe_variable_sources, unsafe_variable_names)
safe_variable_sources contains the actual source code lines for safe variables
"""
classes: Set[str] = set()
functions: Set[str] = set()
safe_variable_sources: list[str] = []
unsafe_variables: Set[str] = set()
for node in tree.body:
if isinstance(node, ast.ClassDef):
if not _is_private(node.name):
classes.add(node.name)
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if not _is_private(node.name):
functions.add(node.name)
elif isinstance(node, ast.Assign):
is_safe = _is_safe_type_alias(node)
names = []
for t in node.targets:
for n in _iter_assigned_names(t):
if not _is_private(n):
names.append(n)
if names:
if is_safe:
# Extract the source code for this assignment
start_line = node.lineno - 1 # 0-indexed
end_line = node.end_lineno if node.end_lineno else node.lineno
source = "\n".join(source_lines[start_line:end_line])
safe_variable_sources.append(source)
else:
unsafe_variables.update(names)
elif isinstance(node, ast.AnnAssign) and node.target:
# Annotated assignments are always stubbed
for n in _iter_assigned_names(node.target):
if not _is_private(n):
unsafe_variables.add(n)
return classes, functions, safe_variable_sources, unsafe_variables
def find_prisma_types_path() -> Path:
"""Find the prisma types.py file in the installed package."""
spec = importlib.util.find_spec("prisma")
if spec is None or spec.origin is None:
raise RuntimeError("Could not find prisma package. Is it installed?")
prisma_dir = Path(spec.origin).parent
types_path = prisma_dir / "types.py"
if not types_path.exists():
raise RuntimeError(f"prisma/types.py not found at {types_path}")
return types_path
def generate_stub(src_path: Path, stub_path: Path) -> int:
"""Generate the .pyi stub file from the source types.py."""
code = src_path.read_text(encoding="utf-8", errors="ignore")
source_lines = code.splitlines()
tree = ast.parse(code, filename=str(src_path))
classes, functions, safe_variable_sources, unsafe_variables = (
collect_top_level_symbols(tree, source_lines)
)
header = """\
# -*- coding: utf-8 -*-
# Auto-generated stub file - DO NOT EDIT
# Generated by gen_prisma_types_stub.py
#
# This stub intentionally collapses complex Prisma query DSL types to Any.
# Prisma's generated types can explode Pyright's type inference budgets
# on large schemas. We collapse them to Any so the rest of the codebase
# can remain strongly typed while keeping runtime behavior unchanged.
#
# Safe types (Literal, TypeVar, simple references) are preserved from the
# original types.py to maintain proper type checking where possible.
from __future__ import annotations
from typing import Any
from typing_extensions import Literal
# Re-export commonly used typing constructs that may be imported from this module
from typing import TYPE_CHECKING, TypeVar, Generic, Union, Optional, List, Dict
# Base type alias for stubbed Prisma types - allows any dict structure
_PrismaDict = dict[str, Any]
"""
lines = [header]
# Include safe variable definitions (Literal types, TypeVars, etc.)
lines.append("# Safe type definitions preserved from original types.py")
for source in safe_variable_sources:
lines.append(source)
lines.append("")
# Stub all classes and unsafe variables uniformly as dict[str, Any] aliases
# This allows:
# 1. Use in type annotations: x: SomeType
# 2. Constructor calls: SomeType(...)
# 3. Dict literal assignments: x: SomeType = {...}
lines.append(
"# Stubbed types (collapsed to dict[str, Any] to prevent type budget exhaustion)"
)
all_stubbed = sorted(classes | unsafe_variables)
for name in all_stubbed:
lines.append(f"{name} = _PrismaDict")
lines.append("")
# Stub functions
for name in sorted(functions):
lines.append(f"def {name}(*args: Any, **kwargs: Any) -> Any: ...")
lines.append("")
stub_path.write_text("\n".join(lines), encoding="utf-8")
return (
len(classes)
+ len(functions)
+ len(safe_variable_sources)
+ len(unsafe_variables)
)
def main() -> None:
"""Main entry point."""
try:
types_path = find_prisma_types_path()
stub_path = types_path.with_suffix(".pyi")
print(f"Found prisma types.py at: {types_path}")
print(f"Generating stub at: {stub_path}")
num_symbols = generate_stub(types_path, stub_path)
print(f"Generated {stub_path.name} with {num_symbols} Any-typed symbols")
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -25,9 +25,6 @@ def run(*command: str) -> None:
def lint():
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
run("gen-prisma-stub")
lint_step_args: list[list[str]] = [
["ruff", "check", *TARGET_DIRS, "--exit-zero"],
["ruff", "format", "--diff", "--check", LIBS_DIR],
@@ -52,6 +49,4 @@ def format():
run("ruff", "format", LIBS_DIR)
run("isort", "--profile", "black", BACKEND_DIR)
run("black", BACKEND_DIR)
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
run("gen-prisma-stub")
run("pyright", *TARGET_DIRS)

View File

@@ -117,7 +117,6 @@ lint = "linter:lint"
test = "run_tests:test"
load-store-agents = "test.load_store_agents:run"
export-api-schema = "backend.cli.generate_openapi_json:main"
gen-prisma-stub = "gen_prisma_types_stub:main"
oauth-tool = "backend.cli.oauth_tool:cli"
[tool.isort]
@@ -135,9 +134,6 @@ ignore_patterns = []
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
# Disable syrupy plugin to avoid conflict with pytest-snapshot
# Both provide --snapshot-update argument causing ArgumentError
addopts = "-p no:syrupy"
filterwarnings = [
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",

View File

@@ -2,7 +2,6 @@
"created_at": "2025-09-04T13:37:00",
"credentials_input_schema": {
"properties": {},
"required": [],
"title": "TestGraphCredentialsInputSchema",
"type": "object"
},

View File

@@ -2,7 +2,6 @@
{
"credentials_input_schema": {
"properties": {},
"required": [],
"title": "TestGraphCredentialsInputSchema",
"type": "object"
},

View File

@@ -4,7 +4,6 @@
"id": "test-agent-1",
"graph_id": "test-agent-1",
"graph_version": 1,
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"image_url": null,
"creator_name": "Test Creator",
"creator_image_url": "",
@@ -42,7 +41,6 @@
"id": "test-agent-2",
"graph_id": "test-agent-2",
"graph_version": 1,
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"image_url": null,
"creator_name": "Test Creator",
"creator_image_url": "",

View File

@@ -1,7 +1,6 @@
{
"submissions": [
{
"listing_id": "test-listing-id",
"agent_id": "test-agent-id",
"agent_version": 1,
"name": "Test Agent",

View File

@@ -22,6 +22,7 @@ import random
from typing import Any, Dict, List
from faker import Faker
from prisma.types import AgentBlockCreateInput
# Import API functions from the backend
from backend.api.features.library.db import create_library_agent, create_preset
@@ -179,12 +180,12 @@ class TestDataCreator:
for block in blocks_to_create:
try:
await prisma.agentblock.create(
data={
"id": block.id,
"name": block.name,
"inputSchema": "{}",
"outputSchema": "{}",
}
data=AgentBlockCreateInput(
id=block.id,
name=block.name,
inputSchema="{}",
outputSchema="{}",
)
)
except Exception as e:
print(f"Error creating block {block.name}: {e}")

View File

@@ -30,13 +30,19 @@ from prisma.types import (
AgentGraphCreateInput,
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
AgentPresetCreateInput,
AnalyticsDetailsCreateInput,
AnalyticsMetricsCreateInput,
APIKeyCreateInput,
CreditTransactionCreateInput,
IntegrationWebhookCreateInput,
LibraryAgentCreateInput,
ProfileCreateInput,
StoreListingCreateInput,
StoreListingReviewCreateInput,
StoreListingVersionCreateInput,
UserCreateInput,
UserOnboardingCreateInput,
)
faker = Faker()
@@ -172,14 +178,14 @@ async def main():
for _ in range(num_presets): # Create 1 AgentPreset per user
graph = random.choice(agent_graphs)
preset = await db.agentpreset.create(
data={
"name": faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"isActive": True,
}
data=AgentPresetCreateInput(
name=faker.sentence(nb_words=3),
description=faker.text(max_nb_chars=200),
userId=user.id,
agentGraphId=graph.id,
agentGraphVersion=graph.version,
isActive=True,
)
)
agent_presets.append(preset)
@@ -220,18 +226,18 @@ async def main():
)
library_agent = await db.libraryagent.create(
data={
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"creatorId": creator_profile.id if creator_profile else None,
"imageUrl": get_image() if random.random() < 0.5 else None,
"useGraphIsActiveVersion": random.choice([True, False]),
"isFavorite": random.choice([True, False]),
"isCreatedByUser": random.choice([True, False]),
"isArchived": random.choice([True, False]),
"isDeleted": random.choice([True, False]),
}
data=LibraryAgentCreateInput(
userId=user.id,
agentGraphId=graph.id,
agentGraphVersion=graph.version,
creatorId=creator_profile.id if creator_profile else None,
imageUrl=get_image() if random.random() < 0.5 else None,
useGraphIsActiveVersion=random.choice([True, False]),
isFavorite=random.choice([True, False]),
isCreatedByUser=random.choice([True, False]),
isArchived=random.choice([True, False]),
isDeleted=random.choice([True, False]),
)
)
library_agents.append(library_agent)
@@ -392,13 +398,13 @@ async def main():
user = random.choice(users)
slug = faker.slug()
listing = await db.storelisting.create(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"owningUserId": user.id,
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,
}
data=StoreListingCreateInput(
agentGraphId=graph.id,
agentGraphVersion=graph.version,
owningUserId=user.id,
hasApprovedVersion=random.choice([True, False]),
slug=slug,
)
)
store_listings.append(listing)
@@ -408,26 +414,26 @@ async def main():
for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
version = await db.storelistingversion.create(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": get_video_url() if random.random() < 0.3 else None,
"imageUrls": [get_image() for _ in range(3)],
"description": faker.text(),
"categories": [faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]),
"isAvailable": True,
"storeListingId": listing.id,
"submissionStatus": random.choice(
data=StoreListingVersionCreateInput(
agentGraphId=graph.id,
agentGraphVersion=graph.version,
name=graph.name or faker.sentence(nb_words=3),
subHeading=faker.sentence(),
videoUrl=get_video_url() if random.random() < 0.3 else None,
imageUrls=[get_image() for _ in range(3)],
description=faker.text(),
categories=[faker.word() for _ in range(3)],
isFeatured=random.choice([True, False]),
isAvailable=True,
storeListingId=listing.id,
submissionStatus=random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
),
}
)
)
store_listing_versions.append(version)
@@ -469,51 +475,49 @@ async def main():
try:
await db.useronboarding.create(
data={
"userId": user.id,
"completedSteps": completed_steps,
"walletShown": random.choice([True, False]),
"notified": (
data=UserOnboardingCreateInput(
userId=user.id,
completedSteps=completed_steps,
walletShown=random.choice([True, False]),
notified=(
random.sample(completed_steps, k=min(3, len(completed_steps)))
if completed_steps
else []
),
"rewardedFor": (
rewardedFor=(
random.sample(completed_steps, k=min(2, len(completed_steps)))
if completed_steps
else []
),
"usageReason": (
usageReason=(
random.choice(["personal", "business", "research", "learning"])
if random.random() < 0.7
else None
),
"integrations": random.sample(
integrations=random.sample(
["github", "google", "discord", "slack"], k=random.randint(0, 2)
),
"otherIntegrations": (
faker.word() if random.random() < 0.2 else None
),
"selectedStoreListingVersionId": (
otherIntegrations=(faker.word() if random.random() < 0.2 else None),
selectedStoreListingVersionId=(
random.choice(store_listing_versions).id
if store_listing_versions and random.random() < 0.5
else None
),
"onboardingAgentExecutionId": (
onboardingAgentExecutionId=(
random.choice(agent_graph_executions).id
if agent_graph_executions and random.random() < 0.3
else None
),
"agentRuns": random.randint(0, 10),
}
agentRuns=random.randint(0, 10),
)
)
except Exception as e:
print(f"Error creating onboarding for user {user.id}: {e}")
# Try simpler version
await db.useronboarding.create(
data={
"userId": user.id,
}
data=UserOnboardingCreateInput(
userId=user.id,
)
)
# Insert IntegrationWebhooks for some users
@@ -544,20 +548,20 @@ async def main():
for user in users:
api_key = APIKeySmith().generate_key()
await db.apikey.create(
data={
"name": faker.word(),
"head": api_key.head,
"tail": api_key.tail,
"hash": api_key.hash,
"salt": api_key.salt,
"status": prisma.enums.APIKeyStatus.ACTIVE,
"permissions": [
data=APIKeyCreateInput(
name=faker.word(),
head=api_key.head,
tail=api_key.tail,
hash=api_key.hash,
salt=api_key.salt,
status=prisma.enums.APIKeyStatus.ACTIVE,
permissions=[
prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
prisma.enums.APIKeyPermission.READ_GRAPH,
],
"description": faker.text(),
"userId": user.id,
}
description=faker.text(),
userId=user.id,
)
)
# Refresh materialized views

View File

@@ -16,6 +16,7 @@ from datetime import datetime, timedelta
import prisma.enums
from faker import Faker
from prisma import Json, Prisma
from prisma.types import CreditTransactionCreateInput, StoreListingReviewCreateInput
faker = Faker()
@@ -166,16 +167,16 @@ async def main():
score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
await db.storelistingreview.create(
data={
"storeListingVersionId": version.id,
"reviewByUserId": reviewer.id,
"score": score,
"comments": (
data=StoreListingReviewCreateInput(
storeListingVersionId=version.id,
reviewByUserId=reviewer.id,
score=score,
comments=(
faker.text(max_nb_chars=200)
if random.random() < 0.7
else None
),
}
)
)
new_reviews_count += 1
@@ -244,17 +245,17 @@ async def main():
)
await db.credittransaction.create(
data={
"userId": user.id,
"amount": amount,
"type": transaction_type,
"metadata": Json(
data=CreditTransactionCreateInput(
userId=user.id,
amount=amount,
type=transaction_type,
metadata=Json(
{
"source": "test_updater",
"timestamp": datetime.now().isoformat(),
}
),
}
)
)
transaction_count += 1

View File

@@ -37,7 +37,7 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: migrate
command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"]
command: ["sh", "-c", "poetry run prisma generate && poetry run prisma migrate deploy"]
develop:
watch:
- path: ./

View File

@@ -66,7 +66,6 @@ export const RunInputDialog = ({
formContext={{
showHandles: false,
size: "large",
showOptionalToggle: false,
}}
/>
</div>

View File

@@ -66,7 +66,7 @@ export const useRunInputDialog = ({
if (isCredentialFieldSchema(fieldSchema)) {
dynamicUiSchema[fieldName] = {
...dynamicUiSchema[fieldName],
"ui:field": "custom/credential_field",
"ui:field": "credentials",
};
}
});
@@ -76,18 +76,12 @@ export const useRunInputDialog = ({
}, [credentialsSchema]);
const handleManualRun = async () => {
// Filter out incomplete credentials (those without a valid id)
// RJSF auto-populates const values (provider, type) but not id field
const validCredentials = Object.fromEntries(
Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id),
);
await executeGraph({
graphId: flowID ?? "",
graphVersion: flowVersion || null,
data: {
inputs: inputValues,
credentials_inputs: validCredentials,
credentials_inputs: credentialValues,
source: "builder",
},
});

View File

@@ -68,10 +68,7 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
<Tooltip>
<TooltipTrigger asChild>
<div>
<Text
variant="large-semibold"
className="line-clamp-1 hover:cursor-text"
>
<Text variant="large-semibold" className="line-clamp-1">
{beautifyString(title).replace("Block", "").trim()}
</Text>
</div>

View File

@@ -89,18 +89,6 @@ export function extractOptions(
// get display type and color for schema types [need for type display next to field name]
export const getTypeDisplayInfo = (schema: any) => {
if (
schema?.type === "array" &&
"format" in schema &&
schema.format === "table"
) {
return {
displayType: "table",
colorClass: "!text-indigo-500",
hexColor: "#6366f1",
};
}
if (schema?.type === "string" && schema?.format) {
const formatMap: Record<
string,

View File

@@ -1,6 +1,6 @@
export const uiSchema = {
credentials: {
"ui:field": "custom/credential_field",
"ui:field": "credentials",
provider: { "ui:widget": "hidden" },
type: { "ui:widget": "hidden" },
id: { "ui:autofocus": true },

View File

@@ -68,9 +68,6 @@ type NodeStore = {
clearAllNodeErrors: () => void; // Add this
syncHardcodedValuesWithHandleIds: (nodeId: string) => void;
// Credentials optional helpers
setCredentialsOptional: (nodeId: string, optional: boolean) => void;
};
export const useNodeStore = create<NodeStore>((set, get) => ({
@@ -229,9 +226,6 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
...(node.data.metadata?.customized_name !== undefined && {
customized_name: node.data.metadata.customized_name,
}),
...(node.data.metadata?.credentials_optional !== undefined && {
credentials_optional: node.data.metadata.credentials_optional,
}),
},
};
},
@@ -348,30 +342,4 @@ export const useNodeStore = create<NodeStore>((set, get) => ({
}));
}
},
setCredentialsOptional: (nodeId: string, optional: boolean) => {
set((state) => ({
nodes: state.nodes.map((n) =>
n.id === nodeId
? {
...n,
data: {
...n.data,
metadata: {
...n.data.metadata,
credentials_optional: optional,
},
},
}
: n,
),
}));
const newState = {
nodes: get().nodes,
edges: useEdgeStore.getState().edges,
};
useHistoryStore.getState().pushState(newState);
},
}));

View File

@@ -34,9 +34,7 @@ type Props = {
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
onLoaded?: (loaded: boolean) => void;
readOnly?: boolean;
isOptional?: boolean;
showTitle?: boolean;
variant?: "default" | "node";
};
export function CredentialsInput({
@@ -47,9 +45,7 @@ export function CredentialsInput({
siblingInputs,
onLoaded,
readOnly = false,
isOptional = false,
showTitle = true,
variant = "default",
}: Props) {
const hookData = useCredentialsInput({
schema,
@@ -58,7 +54,6 @@ export function CredentialsInput({
siblingInputs,
onLoaded,
readOnly,
isOptional,
});
if (!isLoaded(hookData)) {
@@ -99,14 +94,7 @@ export function CredentialsInput({
<div className={cn("mb-6", className)}>
{showTitle && (
<div className="mb-2 flex items-center gap-2">
<Text variant="large-medium">
{displayName} credentials
{isOptional && (
<span className="ml-1 text-sm font-normal text-gray-500">
(optional)
</span>
)}
</Text>
<Text variant="large-medium">{displayName} credentials</Text>
{schema.description && (
<InformationTooltip description={schema.description} />
)}
@@ -115,17 +103,14 @@ export function CredentialsInput({
{hasCredentialsToShow ? (
<>
{(credentialsToShow.length > 1 || isOptional) && !readOnly ? (
{credentialsToShow.length > 1 && !readOnly ? (
<CredentialsSelect
credentials={credentialsToShow}
provider={provider}
displayName={displayName}
selectedCredentials={selectedCredential}
onSelectCredential={handleCredentialSelect}
onClearCredential={() => onSelectCredential(undefined)}
readOnly={readOnly}
allowNone={isOptional}
variant={variant}
/>
) : (
<div className="mb-4 space-y-2">

View File

@@ -30,8 +30,6 @@ type CredentialRowProps = {
readOnly?: boolean;
showCaret?: boolean;
asSelectTrigger?: boolean;
/** When "node", applies compact styling for node context */
variant?: "default" | "node";
};
export function CredentialRow({
@@ -43,22 +41,14 @@ export function CredentialRow({
readOnly = false,
showCaret = false,
asSelectTrigger = false,
variant = "default",
}: CredentialRowProps) {
const ProviderIcon = providerIcons[provider] || fallbackIcon;
const isNodeVariant = variant === "node";
return (
<div
className={cn(
"flex items-center gap-3 rounded-medium border border-zinc-200 bg-white p-3 transition-colors",
asSelectTrigger && isNodeVariant
? "min-w-0 flex-1 overflow-hidden border-0 bg-transparent"
: asSelectTrigger
? "border-0 bg-transparent"
: readOnly
? "w-fit"
: "",
asSelectTrigger ? "border-0 bg-transparent" : readOnly ? "w-fit" : "",
)}
onClick={readOnly || showCaret || asSelectTrigger ? undefined : onSelect}
style={
@@ -71,31 +61,19 @@ export function CredentialRow({
<ProviderIcon className="h-3 w-3 text-white" />
</div>
<IconKey className="h-5 w-5 shrink-0 text-zinc-800" />
<div
className={cn(
"flex min-w-0 flex-1 flex-nowrap items-center gap-4",
isNodeVariant && "overflow-hidden",
)}
>
<div className="flex min-w-0 flex-1 flex-nowrap items-center gap-4">
<Text
variant="body"
className={cn(
"tracking-tight",
isNodeVariant
? "truncate"
: "line-clamp-1 flex-[0_0_50%] text-ellipsis",
)}
className="line-clamp-1 flex-[0_0_50%] text-ellipsis tracking-tight"
>
{getCredentialDisplayName(credential, displayName)}
</Text>
{!(asSelectTrigger && isNodeVariant) && (
<Text
variant="large"
className="relative top-1 hidden overflow-hidden whitespace-nowrap font-mono tracking-tight md:block"
>
{"*".repeat(MASKED_KEY_LENGTH)}
</Text>
)}
<Text
variant="large"
className="lex-[0_0_40%] relative top-1 hidden overflow-hidden whitespace-nowrap font-mono tracking-tight md:block"
>
{"*".repeat(MASKED_KEY_LENGTH)}
</Text>
</div>
{showCaret && !asSelectTrigger && (
<CaretDown className="h-4 w-4 shrink-0 text-gray-400" />

View File

@@ -7,7 +7,6 @@ import {
} from "@/components/__legacy__/ui/select";
import { Text } from "@/components/atoms/Text/Text";
import { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
import { cn } from "@/lib/utils";
import { useEffect } from "react";
import { getCredentialDisplayName } from "../../helpers";
import { CredentialRow } from "../CredentialRow/CredentialRow";
@@ -24,11 +23,7 @@ interface Props {
displayName: string;
selectedCredentials?: CredentialsMetaInput;
onSelectCredential: (credentialId: string) => void;
onClearCredential?: () => void;
readOnly?: boolean;
allowNone?: boolean;
/** When "node", applies compact styling for node context */
variant?: "default" | "node";
}
export function CredentialsSelect({
@@ -37,38 +32,22 @@ export function CredentialsSelect({
displayName,
selectedCredentials,
onSelectCredential,
onClearCredential,
readOnly = false,
allowNone = true,
variant = "default",
}: Props) {
// Auto-select first credential if none is selected (only if allowNone is false)
// Auto-select first credential if none is selected
useEffect(() => {
if (!allowNone && !selectedCredentials && credentials.length > 0) {
if (!selectedCredentials && credentials.length > 0) {
onSelectCredential(credentials[0].id);
}
}, [allowNone, selectedCredentials, credentials, onSelectCredential]);
const handleValueChange = (value: string) => {
if (value === "__none__") {
onClearCredential?.();
} else {
onSelectCredential(value);
}
};
}, [selectedCredentials, credentials, onSelectCredential]);
return (
<div className="mb-4 w-full">
<Select
value={selectedCredentials?.id || (allowNone ? "__none__" : "")}
onValueChange={handleValueChange}
value={selectedCredentials?.id || ""}
onValueChange={(value) => onSelectCredential(value)}
>
<SelectTrigger
className={cn(
"h-auto min-h-12 w-full rounded-medium border-zinc-200 p-0 pr-4 shadow-none",
variant === "node" && "overflow-hidden",
)}
>
<SelectTrigger className="h-auto min-h-12 w-full rounded-medium border-zinc-200 p-0 pr-4 shadow-none">
{selectedCredentials ? (
<SelectValue key={selectedCredentials.id} asChild>
<CredentialRow
@@ -84,7 +63,6 @@ export function CredentialsSelect({
onDelete={() => {}}
readOnly={readOnly}
asSelectTrigger={true}
variant={variant}
/>
</SelectValue>
) : (
@@ -92,15 +70,6 @@ export function CredentialsSelect({
)}
</SelectTrigger>
<SelectContent>
{allowNone && (
<SelectItem key="__none__" value="__none__">
<div className="flex items-center gap-2">
<Text variant="body" className="tracking-tight text-gray-500">
None (skip this credential)
</Text>
</div>
</SelectItem>
)}
{credentials.map((credential) => (
<SelectItem key={credential.id} value={credential.id}>
<div className="flex items-center gap-2">

View File

@@ -22,7 +22,6 @@ type Params = {
siblingInputs?: Record<string, any>;
onLoaded?: (loaded: boolean) => void;
readOnly?: boolean;
isOptional?: boolean;
};
export function useCredentialsInput({
@@ -32,7 +31,6 @@ export function useCredentialsInput({
siblingInputs,
onLoaded,
readOnly = false,
isOptional = false,
}: Params) {
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
useState(false);
@@ -101,20 +99,13 @@ export function useCredentialsInput({
: null;
}, [credentials]);
// Auto-select the one available credential (only if not optional)
// Auto-select the one available credential
useEffect(() => {
if (readOnly) return;
if (isOptional) return; // Don't auto-select when credential is optional
if (singleCredential && !selectedCredential) {
onSelectCredential(singleCredential);
}
}, [
singleCredential,
selectedCredential,
onSelectCredential,
readOnly,
isOptional,
]);
}, [singleCredential, selectedCredential, onSelectCredential, readOnly]);
if (
!credentials ||

View File

@@ -8,7 +8,6 @@ import { WebhookTriggerBanner } from "../WebhookTriggerBanner/WebhookTriggerBann
export function ModalRunSection() {
const {
agent,
defaultRunType,
presetName,
setPresetName,
@@ -25,11 +24,6 @@ export function ModalRunSection() {
const inputFields = Object.entries(agentInputFields || {});
const credentialFields = Object.entries(agentCredentialsInputFields || {});
// Get the list of required credentials from the schema
const requiredCredentials = new Set(
(agent.credentials_input_schema?.required as string[]) || [],
);
return (
<div className="flex flex-col gap-4">
{defaultRunType === "automatic-trigger" ||
@@ -105,12 +99,14 @@ export function ModalRunSection() {
schema={
{ ...inputSubSchema, discriminator: undefined } as any
}
selectedCredentials={inputCredentials?.[key]}
selectedCredentials={
(inputCredentials && inputCredentials[key]) ??
inputSubSchema.default
}
onSelectCredentials={(value) =>
setInputCredentialsValue(key, value)
}
siblingInputs={inputValues}
isOptional={!requiredCredentials.has(key)}
/>
),
)}

View File

@@ -163,21 +163,15 @@ export function useAgentRunModal(
}, [agentInputSchema.required, inputValues]);
const [allCredentialsAreSet, missingCredentials] = useMemo(() => {
// Only check required credentials from schema, not all properties
// Credentials marked as optional in node metadata won't be in the required array
const requiredCredentials = new Set(
(agent.credentials_input_schema?.required as string[]) || [],
const availableCredentials = new Set(Object.keys(inputCredentials));
const allCredentials = new Set(
Object.keys(agentCredentialsInputFields || {}) ?? [],
);
const missing = [...allCredentials].filter(
(key) => !availableCredentials.has(key),
);
// Check if required credentials have valid id (not just key existence)
// A credential is valid only if it has an id field set
const missing = [...requiredCredentials].filter((key) => {
const cred = inputCredentials[key];
return !cred || !cred.id;
});
return [missing.length === 0, missing];
}, [agent.credentials_input_schema, inputCredentials]);
}, [agentCredentialsInputFields, inputCredentials]);
const credentialsRequired = useMemo(
() => Object.keys(agentCredentialsInputFields || {}).length > 0,
@@ -245,18 +239,12 @@ export function useAgentRunModal(
});
} else {
// Manual execution
// Filter out incomplete credentials (optional ones not selected)
// Only send credentials that have a valid id field
const validCredentials = Object.fromEntries(
Object.entries(inputCredentials).filter(([_, cred]) => cred && cred.id),
);
executeGraphMutation.mutate({
graphId: agent.graph_id,
graphVersion: agent.graph_version,
data: {
inputs: inputValues,
credentials_inputs: validCredentials,
credentials_inputs: inputCredentials,
source: "library",
},
});

View File

@@ -40,17 +40,15 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
},
);
// Get user's submissions - only fetch if user is the creator
const { data: submissionsData, isLoading: isSubmissionsLoading } =
useGetV2ListMySubmissions(
{ page: 1, page_size: 50 },
{
query: {
// Only fetch if user is the creator
enabled: !!(user?.id && agent?.owner_user_id === user.id),
},
// Get user's submissions to check for pending submissions
const { data: submissionsData } = useGetV2ListMySubmissions(
{ page: 1, page_size: 50 }, // Get enough to cover recent submissions
{
query: {
enabled: !!user?.id, // Only fetch if user is authenticated
},
);
},
);
const updateToLatestMutation = usePatchV2UpdateLibraryAgent({
mutation: {
@@ -80,45 +78,16 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
// Check if marketplace has a newer version than user's current version
const marketplaceUpdateInfo = React.useMemo(() => {
const storeAgent = okData(storeAgentData) as any;
if (!agent || isSubmissionsLoading) {
if (!agent || !storeAgent) {
return {
hasUpdate: false,
latestVersion: undefined,
isUserCreator: false,
hasPublishUpdate: false,
};
}
const isUserCreator = agent?.owner_user_id === user?.id;
const submissionsResponse = okData(submissionsData) as any;
const agentSubmissions =
submissionsResponse?.submissions?.filter(
(submission: StoreSubmission) => submission.agent_id === agent.graph_id,
) || [];
const highestSubmittedVersion =
agentSubmissions.length > 0
? Math.max(
...agentSubmissions.map(
(submission: StoreSubmission) => submission.agent_version,
),
)
: 0;
const hasUnpublishedChanges =
isUserCreator && agent.graph_version > highestSubmittedVersion;
if (!storeAgent) {
return {
hasUpdate: false,
latestVersion: undefined,
isUserCreator,
hasPublishUpdate: agentSubmissions.length > 0 && hasUnpublishedChanges,
};
}
// Get the latest version from the marketplace
// agentGraphVersions array contains graph version numbers as strings, get the highest one
const latestMarketplaceVersion =
storeAgent.agentGraphVersions?.length > 0
? Math.max(
@@ -128,11 +97,32 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
)
: undefined;
// Determine if the user is the creator of this agent
// Compare current user ID with the marketplace listing creator ID
const isUserCreator =
user?.id && agent.marketplace_listing?.creator.id === user.id;
// Check if there's a pending submission for this specific agent version
const submissionsResponse = okData(submissionsData) as any;
const hasPendingSubmissionForCurrentVersion =
isUserCreator &&
submissionsResponse?.submissions?.some(
(submission: StoreSubmission) =>
submission.agent_id === agent.graph_id &&
submission.agent_version === agent.graph_version &&
submission.status === "PENDING",
);
// If user is creator and their version is newer than marketplace, show publish update banner
// BUT only if there's no pending submission for this version
const hasPublishUpdate =
isUserCreator &&
agent.graph_version >
Math.max(latestMarketplaceVersion || 0, highestSubmittedVersion);
!hasPendingSubmissionForCurrentVersion &&
latestMarketplaceVersion !== undefined &&
agent.graph_version > latestMarketplaceVersion;
// If marketplace version is newer than user's version, show update banner
// This applies to both creators and non-creators
const hasMarketplaceUpdate =
latestMarketplaceVersion !== undefined &&
latestMarketplaceVersion > agent.graph_version;
@@ -143,7 +133,7 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
isUserCreator,
hasPublishUpdate,
};
}, [agent, storeAgentData, user, submissionsData, isSubmissionsLoading]);
}, [agent, storeAgentData, user, submissionsData]);
const handlePublishUpdate = () => {
setModalOpen(true);

View File

@@ -1,17 +1,15 @@
"use client";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { Text } from "@/components/atoms/Text/Text";
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { HeartIcon } from "@phosphor-icons/react";
import { useFavoriteAgents } from "../../hooks/useFavoriteAgents";
import { LibraryAgentCard } from "../LibraryAgentCard/LibraryAgentCard";
interface Props {
searchTerm: string;
}
export function FavoritesSection({ searchTerm }: Props) {
export function FavoritesSection() {
const isAgentFavoritingEnabled = useGetFlag(Flag.AGENT_FAVORITING);
const {
allAgents: favoriteAgents,
agentLoading: isLoading,
@@ -19,50 +17,60 @@ export function FavoritesSection({ searchTerm }: Props) {
hasNextPage,
fetchNextPage,
isFetchingNextPage,
} = useFavoriteAgents({ searchTerm });
} = useFavoriteAgents();
if (isLoading || favoriteAgents.length === 0) {
// Only show this section if the feature flag is enabled
if (!isAgentFavoritingEnabled) {
return null;
}
// Don't show the section if there are no favorites
if (!isLoading && favoriteAgents.length === 0) {
return null;
}
return (
<div className="!mb-8">
<div className="mb-3 flex items-center gap-2 p-2">
<HeartIcon className="h-5 w-5" weight="fill" />
<div className="flex items-baseline gap-2">
<Text variant="h4">Favorites</Text>
{!isLoading && (
<Text
variant="body"
data-testid="agents-count"
className="relative bottom-px text-zinc-500"
>
{agentCount}
</Text>
)}
</div>
<div className="mb-8">
<div className="flex items-center gap-[10px] p-2 pb-[10px]">
<HeartIcon className="h-5 w-5 fill-red-500 text-red-500" />
<span className="font-poppin text-[18px] font-semibold leading-[28px] text-neutral-800">
Favorites
</span>
{!isLoading && (
<span className="font-sans text-[14px] font-normal leading-6">
{agentCount} {agentCount === 1 ? "agent" : "agents"}
</span>
)}
</div>
<div className="relative">
<InfiniteScroll
isFetchingNextPage={isFetchingNextPage}
fetchNextPage={fetchNextPage}
hasNextPage={hasNextPage}
loader={
<div className="flex h-8 w-full items-center justify-center">
<div className="h-6 w-6 animate-spin rounded-full border-b-2 border-t-2 border-neutral-800" />
</div>
}
>
{isLoading ? (
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
{favoriteAgents.map((agent: LibraryAgent) => (
<LibraryAgentCard key={agent.id} agent={agent} />
{[...Array(4)].map((_, i) => (
<Skeleton key={i} className="h-48 w-full rounded-lg" />
))}
</div>
</InfiniteScroll>
) : (
<InfiniteScroll
isFetchingNextPage={isFetchingNextPage}
fetchNextPage={fetchNextPage}
hasNextPage={hasNextPage}
loader={
<div className="flex h-8 w-full items-center justify-center">
<div className="h-6 w-6 animate-spin rounded-full border-b-2 border-t-2 border-neutral-800" />
</div>
}
>
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
{favoriteAgents.map((agent: LibraryAgent) => (
<LibraryAgentCard key={agent.id} agent={agent} />
))}
</div>
</InfiniteScroll>
)}
</div>
{favoriteAgents.length > 0 && <div className="!mt-10 border-t" />}
{favoriteAgents.length > 0 && <div className="mt-6 border-t pt-6" />}
</div>
);
}

View File

@@ -24,6 +24,7 @@ export function LibraryAgentCard({ agent }: Props) {
const {
isFromMarketplace,
isAgentFavoritingEnabled,
isFavorite,
profile,
creator_image_url,
@@ -36,8 +37,9 @@ export function LibraryAgentCard({ agent }: Props) {
data-agent-id={id}
className="group relative inline-flex h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border border-zinc-100 bg-white transition-all duration-300 hover:shadow-md"
>
<NextLink href={`/library/agents/${id}`} className="flex-shrink-0">
<div className="relative flex items-center gap-2 px-4 pt-3">
<AgentCardMenu agent={agent} />
<NextLink href={`/library/agents/${id}`} className="w-full flex-shrink-0">
<div className="flex items-center gap-2 px-4 pt-3">
<Avatar className="h-4 w-4 rounded-full">
<AvatarImage
src={
@@ -56,13 +58,13 @@ export function LibraryAgentCard({ agent }: Props) {
{isFromMarketplace ? "FROM MARKETPLACE" : "Built by you"}
</Text>
</div>
{isAgentFavoritingEnabled && (
<FavoriteButton
isFavorite={isFavorite}
onClick={handleToggleFavorite}
/>
)}
</NextLink>
<FavoriteButton
isFavorite={isFavorite}
onClick={handleToggleFavorite}
className="absolute right-10 top-0"
/>
<AgentCardMenu agent={agent} />
<div className="flex w-full flex-1 flex-col px-4 pb-2">
<Link

View File

@@ -2,27 +2,21 @@
import { cn } from "@/lib/utils";
import { HeartIcon } from "@phosphor-icons/react";
import type { MouseEvent } from "react";
interface FavoriteButtonProps {
isFavorite: boolean;
onClick: (e: MouseEvent<HTMLButtonElement>) => void;
className?: string;
onClick: (e: React.MouseEvent) => void;
}
export function FavoriteButton({
isFavorite,
onClick,
className,
}: FavoriteButtonProps) {
export function FavoriteButton({ isFavorite, onClick }: FavoriteButtonProps) {
return (
<button
onClick={onClick}
className={cn(
"rounded-full p-2 transition-all duration-200",
"hover:scale-110",
"rounded-full bg-white/90 p-2 backdrop-blur-sm transition-all duration-200",
"hover:scale-110 hover:bg-white",
"focus:outline-none focus:ring-2 focus:ring-red-500 focus:ring-offset-2",
!isFavorite && "opacity-0 group-hover:opacity-100",
className,
)}
aria-label={isFavorite ? "Remove from favorites" : "Add to favorites"}
>

View File

@@ -8,6 +8,7 @@ import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { updateFavoriteInQueries } from "./helpers";
interface Props {
@@ -19,6 +20,7 @@ export function useLibraryAgentCard({ agent }: Props) {
agent;
const isFromMarketplace = Boolean(marketplace_listing);
const isAgentFavoritingEnabled = useGetFlag(Flag.AGENT_FAVORITING);
const [isFavorite, setIsFavorite] = useState(is_favorite);
const { toast } = useToast();
const queryClient = getQueryClient();
@@ -47,6 +49,8 @@ export function useLibraryAgentCard({ agent }: Props) {
e.preventDefault();
e.stopPropagation();
if (!isAgentFavoritingEnabled) return;
const newIsFavorite = !isFavorite;
setIsFavorite(newIsFavorite);
@@ -76,6 +80,7 @@ export function useLibraryAgentCard({ agent }: Props) {
return {
isFromMarketplace,
isAgentFavoritingEnabled,
isFavorite,
profile,
creator_image_url,

View File

@@ -1,15 +1,13 @@
"use client";
import {
getPaginatedTotalCount,
getPaginationNextPageNumber,
unpaginate,
} from "@/app/api/helpers";
import { useGetV2ListFavoriteLibraryAgentsInfinite } from "@/app/api/__generated__/endpoints/library/library";
import { getPaginationNextPageNumber, unpaginate } from "@/app/api/helpers";
import { useMemo } from "react";
import { filterAgents } from "../components/LibraryAgentList/helpers";
interface Props {
searchTerm: string;
}
export function useFavoriteAgents({ searchTerm }: Props) {
export function useFavoriteAgents() {
const {
data: agentsQueryData,
fetchNextPage,
@@ -29,16 +27,10 @@ export function useFavoriteAgents({ searchTerm }: Props) {
const allAgents = agentsQueryData
? unpaginate(agentsQueryData, "agents")
: [];
const filteredAgents = useMemo(
() => filterAgents(allAgents, searchTerm),
[allAgents, searchTerm],
);
const agentCount = filteredAgents.length;
const agentCount = getPaginatedTotalCount(agentsQueryData);
return {
allAgents: filteredAgents,
allAgents,
agentLoading,
hasNextPage,
agentCount,

View File

@@ -17,7 +17,7 @@ export default function LibraryPage() {
return (
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
<LibraryActionHeader setSearchTerm={setSearchTerm} />
<FavoritesSection searchTerm={searchTerm} />
<FavoritesSection />
<LibraryAgentList
searchTerm={searchTerm}
librarySort={librarySort}

View File

@@ -12,7 +12,6 @@ import type { GetV2GetSpecificAgentParams } from "@/app/api/__generated__/models
import { useAgentInfo } from "./useAgentInfo";
import { useGetV2GetSpecificAgent } from "@/app/api/__generated__/endpoints/store/store";
import { Text } from "@/components/atoms/Text/Text";
import { formatTimeAgo } from "@/lib/utils/time";
import * as React from "react";
interface AgentInfoProps {
@@ -259,29 +258,23 @@ export const AgentInfo = ({
</div>
</div>
{/* Version history */}
{/* Changelog */}
<div className="flex w-full flex-col gap-1.5 sm:gap-2">
<div className="decoration-skip-ink-none text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200">
Version history
<div className="decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-800 dark:text-neutral-200 sm:mb-2">
Changelog
</div>
<div className="decoration-skip-ink-none text-sm font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
Last updated {formatTimeAgo(lastUpdated)}
</div>
<div className="decoration-skip-ink-none text-xs text-neutral-600 dark:text-neutral-400 sm:text-sm">
Version {version}.0
<div className="decoration-skip-ink-none text-base font-normal leading-6 text-neutral-600 underline-offset-[from-font] dark:text-neutral-400">
Last updated {lastUpdated}
</div>
{/* Version List */}
{agentVersions.length > 0 ? (
<div className="mt-3">
<div className="decoration-skip-ink-none mb-1.5 text-base font-medium leading-6 text-neutral-900 dark:text-neutral-200 sm:mb-2">
Changelog
</div>
<div className="mt-4">
{agentVersions.map(renderVersionItem)}
{hasMoreVersions && (
<button
onClick={() => setVisibleVersionCount((prev) => prev + 3)}
className="mt-2 flex items-center gap-1 text-sm font-medium text-neutral-700 hover:text-neutral-700 dark:text-neutral-100 dark:hover:text-neutral-300"
className="mt-2 flex items-center gap-1 text-sm font-medium text-neutral-900 hover:text-neutral-700 dark:text-neutral-100 dark:hover:text-neutral-300"
>
<svg
width="16"
@@ -304,7 +297,7 @@ export const AgentInfo = ({
</div>
) : (
<div className="text-xs text-neutral-600 dark:text-neutral-400 sm:text-sm">
Version {version}.0
Version {version}
</div>
)}
</div>

View File

@@ -18,7 +18,6 @@ export interface AgentTableCardProps {
runs: number;
rating: number;
id: number;
listing_id?: string;
onViewSubmission: (submission: StoreSubmission) => void;
}
@@ -33,12 +32,10 @@ export const AgentTableCard = ({
status,
runs,
rating,
listing_id,
onViewSubmission,
}: AgentTableCardProps) => {
const onView = () => {
onViewSubmission({
listing_id: listing_id || "",
agent_id,
agent_version,
slug: "",
@@ -65,14 +62,9 @@ export const AgentTableCard = ({
/>
</div>
<div className="flex-1">
<div className="flex items-center gap-2">
<h3 className="text-[15px] font-medium text-neutral-800 dark:text-neutral-200">
{agentName}
</h3>
<span className="text-[13px] text-neutral-500 dark:text-neutral-400">
v{agent_version}
</span>
</div>
<h3 className="text-[15px] font-medium text-neutral-800 dark:text-neutral-200">
{agentName}
</h3>
<p className="line-clamp-2 text-sm text-neutral-600 dark:text-neutral-400">
{description}
</p>

View File

@@ -9,11 +9,11 @@ import { useAgentTableRow } from "./useAgentTableRow";
import { StoreSubmission } from "@/app/api/__generated__/models/storeSubmission";
import {
DotsThreeVerticalIcon,
EyeIcon,
Eye,
ImageBroken,
StarIcon,
TrashIcon,
PencilIcon,
Star,
Trash,
PencilSimple,
} from "@phosphor-icons/react/dist/ssr";
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
import { StoreSubmissionEditRequest } from "@/app/api/__generated__/models/storeSubmissionEditRequest";
@@ -34,7 +34,6 @@ export interface AgentTableRowProps {
categories?: string[];
store_listing_version_id?: string;
changes_summary?: string;
listing_id?: string;
onViewSubmission: (submission: StoreSubmission) => void;
onDeleteSubmission: (submission_id: string) => void;
onEditSubmission: (
@@ -61,7 +60,6 @@ export const AgentTableRow = ({
categories,
store_listing_version_id,
changes_summary,
listing_id,
onViewSubmission,
onDeleteSubmission,
onEditSubmission,
@@ -85,10 +83,11 @@ export const AgentTableRow = ({
categories,
store_listing_version_id,
changes_summary,
listing_id,
});
const canModify = status === SubmissionStatus.PENDING;
// Determine if we should show Edit or View button
const canEdit =
status === SubmissionStatus.APPROVED || status === SubmissionStatus.PENDING;
return (
<div
@@ -115,22 +114,13 @@ export const AgentTableRow = ({
</div>
)}
<div className="flex flex-col">
<div className="flex items-center gap-2">
<Text
variant="h3"
className="line-clamp-1 text-ellipsis text-neutral-800 dark:text-neutral-200"
size="large-medium"
>
{agentName}
</Text>
<Text
variant="body"
size="small"
className="text-neutral-500 dark:text-neutral-400"
>
v{agent_version}
</Text>
</div>
<Text
variant="h3"
className="line-clamp-1 text-ellipsis text-neutral-800 dark:text-neutral-200"
size="large-medium"
>
{agentName}
</Text>
<Text
variant="body"
className="line-clamp-1 text-ellipsis text-neutral-600 dark:text-neutral-400"
@@ -160,7 +150,7 @@ export const AgentTableRow = ({
{rating ? (
<div className="flex items-center justify-end gap-1">
<span className="text-sm font-medium">{rating.toFixed(1)}</span>
<StarIcon weight="fill" className="h-2 w-2" />
<Star weight="fill" className="h-2 w-2" />
</div>
) : (
<span className="text-sm text-neutral-600 dark:text-neutral-400">
@@ -176,12 +166,12 @@ export const AgentTableRow = ({
<DotsThreeVerticalIcon className="h-5 w-5 text-neutral-800" />
</DropdownMenu.Trigger>
<DropdownMenu.Content className="z-10 rounded-xl border bg-white p-1 shadow-md dark:bg-gray-800">
{canModify ? (
{canEdit ? (
<DropdownMenu.Item
onSelect={handleEdit}
className="flex cursor-pointer items-center rounded-md px-3 py-2 hover:bg-gray-100 dark:hover:bg-gray-700"
>
<PencilIcon className="mr-2 h-4 w-4 dark:text-gray-100" />
<PencilSimple className="mr-2 h-4 w-4 dark:text-gray-100" />
<span className="dark:text-gray-100">Edit</span>
</DropdownMenu.Item>
) : (
@@ -189,22 +179,18 @@ export const AgentTableRow = ({
onSelect={handleView}
className="flex cursor-pointer items-center rounded-md px-3 py-2 hover:bg-gray-100 dark:hover:bg-gray-700"
>
<EyeIcon className="mr-2 h-4 w-4 dark:text-gray-100" />
<Eye className="mr-2 h-4 w-4 dark:text-gray-100" />
<span className="dark:text-gray-100">View</span>
</DropdownMenu.Item>
)}
{canModify && (
<>
<DropdownMenu.Separator className="my-1 h-px bg-gray-300 dark:bg-gray-600" />
<DropdownMenu.Item
onSelect={handleDelete}
className="flex cursor-pointer items-center rounded-md px-3 py-2 text-red-500 hover:bg-gray-100 dark:hover:bg-gray-700"
>
<TrashIcon className="mr-2 h-4 w-4 text-red-500 dark:text-red-400" />
<span className="dark:text-red-400">Delete</span>
</DropdownMenu.Item>
</>
)}
<DropdownMenu.Separator className="my-1 h-px bg-gray-300 dark:bg-gray-600" />
<DropdownMenu.Item
onSelect={handleDelete}
className="flex cursor-pointer items-center rounded-md px-3 py-2 text-red-500 hover:bg-gray-100 dark:hover:bg-gray-700"
>
<Trash className="mr-2 h-4 w-4 text-red-500 dark:text-red-400" />
<span className="dark:text-red-400">Delete</span>
</DropdownMenu.Item>
</DropdownMenu.Content>
</DropdownMenu.Root>
</div>

View File

@@ -26,7 +26,6 @@ interface useAgentTableRowProps {
categories?: string[];
store_listing_version_id?: string;
changes_summary?: string;
listing_id?: string;
}
export const useAgentTableRow = ({
@@ -47,11 +46,9 @@ export const useAgentTableRow = ({
categories,
store_listing_version_id,
changes_summary,
listing_id,
}: useAgentTableRowProps) => {
const handleView = () => {
onViewSubmission({
listing_id: listing_id || "",
agent_id,
agent_version,
slug: "",
@@ -84,14 +81,7 @@ export const useAgentTableRow = ({
};
const handleDelete = () => {
// Backend only accepts StoreListingVersion IDs for deletion
if (!store_listing_version_id) {
console.error(
"Cannot delete submission: store_listing_version_id is required",
);
return;
}
onDeleteSubmission(store_listing_version_id);
onDeleteSubmission(agent_id);
};
return { handleView, handleDelete, handleEdit };

View File

@@ -99,7 +99,6 @@ export const MainDashboardPage = () => {
store_listing_version_id:
submission.store_listing_version_id || undefined,
changes_summary: submission.changes_summary || undefined,
listing_id: submission.listing_id,
}))}
onViewSubmission={onViewSubmission}
onDeleteSubmission={onDeleteSubmission}

View File

@@ -7665,7 +7665,6 @@
"id": { "type": "string", "title": "Id" },
"graph_id": { "type": "string", "title": "Graph Id" },
"graph_version": { "type": "integer", "title": "Graph Version" },
"owner_user_id": { "type": "string", "title": "Owner User Id" },
"image_url": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Image Url"
@@ -7748,7 +7747,6 @@
"id",
"graph_id",
"graph_version",
"owner_user_id",
"image_url",
"creator_name",
"creator_image_url",
@@ -9688,7 +9686,6 @@
},
"StoreSubmission": {
"properties": {
"listing_id": { "type": "string", "title": "Listing Id" },
"agent_id": { "type": "string", "title": "Agent Id" },
"agent_version": { "type": "integer", "title": "Agent Version" },
"name": { "type": "string", "title": "Name" },
@@ -9760,7 +9757,6 @@
},
"type": "object",
"required": [
"listing_id",
"agent_id",
"agent_version",
"name",
@@ -9823,18 +9819,8 @@
},
"StoreSubmissionRequest": {
"properties": {
"agent_id": {
"type": "string",
"minLength": 1,
"title": "Agent Id",
"description": "Agent ID cannot be empty"
},
"agent_version": {
"type": "integer",
"exclusiveMinimum": 0.0,
"title": "Agent Version",
"description": "Agent version must be greater than 0"
},
"agent_id": { "type": "string", "title": "Agent Id" },
"agent_version": { "type": "integer", "title": "Agent Version" },
"slug": { "type": "string", "title": "Slug" },
"name": { "type": "string", "title": "Name" },
"sub_heading": { "type": "string", "title": "Sub Heading" },

View File

@@ -27,7 +27,6 @@ export function EditAgentModal({
return (
<Dialog
title="Edit Agent"
styling={{
maxWidth: "45rem",
}}

View File

@@ -8,6 +8,7 @@ import { Form, FormField } from "@/components/__legacy__/ui/form";
import { StoreSubmission } from "@/app/api/__generated__/models/storeSubmission";
import { ThumbnailImages } from "../../PublishAgentModal/components/AgentInfoStep/components/ThumbnailImages";
import { StoreSubmissionEditRequest } from "@/app/api/__generated__/models/storeSubmissionEditRequest";
import { StepHeader } from "../../PublishAgentModal/components/StepHeader";
import { useEditAgentForm } from "./useEditAgentForm";
interface EditAgentFormProps {
@@ -30,10 +31,12 @@ export function EditAgentForm({
isSubmitting,
handleFormSubmit,
handleImagesChange,
} = useEditAgentForm({ submission, onSuccess, onClose });
} = useEditAgentForm({ submission, onSuccess });
return (
<div className="mx-auto flex w-full flex-col rounded-3xl">
<StepHeader title="Edit Agent" description="Update your agent details" />
<Form {...form}>
<form
onSubmit={form.handleSubmit(handleFormSubmit)}
@@ -72,7 +75,7 @@ export function EditAgentForm({
<ThumbnailImages
agentId={submission.agent_id}
onImagesChange={handleImagesChange}
initialImages={Array.from(new Set(submission.image_urls || []))}
initialImages={submission.image_urls || []}
initialSelectedImage={submission.image_urls?.[0] || null}
errorMessage={form.formState.errors.root?.message}
/>
@@ -133,7 +136,7 @@ export function EditAgentForm({
<Input
id={field.name}
label="Changes Summary"
type="textarea"
type="text"
placeholder="Briefly describe what you changed"
error={form.formState.errors.changes_summary?.message}
{...field}

View File

@@ -19,13 +19,11 @@ interface useEditAgentFormProps {
agent_id: string;
};
onSuccess: (submission: StoreSubmission) => void;
onClose: () => void;
}
export const useEditAgentForm = ({
submission,
onSuccess,
onClose,
}: useEditAgentFormProps) => {
const editAgentSchema = z.object({
title: z
@@ -47,7 +45,7 @@ export const useEditAgentForm = ({
changes_summary: z
.string()
.min(1, "Changes summary is required")
.max(500, "Changes summary must be less than 500 characters"),
.max(200, "Changes summary must be less than 200 characters"),
agentOutputDemo: z
.string()
.refine(validateYouTubeUrl, "Please enter a valid YouTube URL"),
@@ -56,11 +54,19 @@ export const useEditAgentForm = ({
type EditAgentFormData = z.infer<typeof editAgentSchema>;
const [images, setImages] = React.useState<string[]>(
Array.from(new Set(submission.image_urls || [])), // Remove duplicates
submission.image_urls || [],
);
const [isSubmitting, setIsSubmitting] = React.useState(false);
const { mutateAsync: editSubmission } = usePutV2EditStoreSubmission();
const { mutateAsync: editSubmission } = usePutV2EditStoreSubmission({
mutation: {
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: getGetV2ListMySubmissionsQueryKey(),
});
},
},
});
const queryClient = useQueryClient();
const { toast } = useToast();
@@ -126,20 +132,7 @@ export const useEditAgentForm = ({
// Extract the StoreSubmission from the response
if (response.status === 200 && response.data) {
toast({
title: "Agent Updated",
description: "Your agent submission has been updated successfully.",
duration: 3000,
variant: "default",
});
queryClient.invalidateQueries({
queryKey: getGetV2ListMySubmissionsQueryKey(),
});
// Call onSuccess and explicitly close the modal
onSuccess(response.data);
onClose();
} else {
throw new Error("Failed to update submission");
}

View File

@@ -1,8 +1,10 @@
import { useEffect, useCallback, useState } from "react";
import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import { useQueryClient } from "@tanstack/react-query";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { getGetV2ListMySubmissionsQueryKey } from "@/app/api/__generated__/endpoints/store/store";
import * as Sentry from "@sentry/nextjs";
import {
PublishAgentFormData,
@@ -31,6 +33,7 @@ export function useAgentInfoStep({
const [images, setImages] = useState<string[]>([]);
const [isSubmitting, setIsSubmitting] = useState(false);
const queryClient = useQueryClient();
const { toast } = useToast();
const api = useBackendAPI();
@@ -51,26 +54,23 @@ export function useAgentInfoStep({
});
useEffect(() => {
if (initialData?.agent_id) {
if (initialData) {
setAgentId(initialData.agent_id);
setImages(
Array.from(
new Set([
...(initialData?.thumbnailSrc ? [initialData.thumbnailSrc] : []),
...(initialData.additionalImages || []),
]),
),
);
const initialImages = [
...(initialData?.thumbnailSrc ? [initialData.thumbnailSrc] : []),
...(initialData.additionalImages || []),
];
setImages(initialImages);
// Update form with initial data
form.reset({
changesSummary: isMarketplaceUpdate
? ""
: initialData.changesSummary || "",
changesSummary: initialData.changesSummary || "",
title: initialData.title,
subheader: initialData.subheader,
slug: initialData.slug.toLocaleLowerCase().trim(),
youtubeLink: initialData.youtubeLink,
category: initialData.category,
description: isMarketplaceUpdate ? "" : initialData.description,
description: initialData.description,
recommendedScheduleCron: initialData.recommendedScheduleCron || "",
instructions: initialData.instructions || "",
agentOutputDemo: initialData.agentOutputDemo || "",
@@ -78,13 +78,6 @@ export function useAgentInfoStep({
}
}, [initialData, form]);
// Ensure agentId is set from selectedAgentId if initialData doesn't have it
useEffect(() => {
if (selectedAgentId && !agentId) {
setAgentId(selectedAgentId);
}
}, [selectedAgentId, agentId]);
const handleImagesChange = useCallback((newImages: string[]) => {
setImages(newImages);
}, []);
@@ -99,16 +92,6 @@ export function useAgentInfoStep({
return;
}
// Validate that an agent is selected before submission
if (!selectedAgentId || !selectedAgentVersion) {
toast({
title: "Agent Selection Required",
description: "Please select an agent before submitting to the store.",
variant: "destructive",
});
return;
}
const categories = data.category ? [data.category] : [];
const filteredCategories = categories.filter(Boolean);
@@ -123,14 +106,18 @@ export function useAgentInfoStep({
image_urls: images,
video_url: data.youtubeLink || "",
agent_output_demo_url: data.agentOutputDemo || "",
agent_id: selectedAgentId,
agent_version: selectedAgentVersion,
agent_id: selectedAgentId || "",
agent_version: selectedAgentVersion || 0,
slug: (data.slug || "").replace(/\s+/g, "-"),
categories: filteredCategories,
recommended_schedule_cron: data.recommendedScheduleCron || null,
changes_summary: data.changesSummary || null,
} as any);
await queryClient.invalidateQueries({
queryKey: getGetV2ListMySubmissionsQueryKey(),
});
onSuccess(response);
} catch (error) {
Sentry.captureException(error);
@@ -152,7 +139,12 @@ export function useAgentInfoStep({
agentId,
images,
isSubmitting,
initialImages: images,
initialImages: initialData
? [
...(initialData?.thumbnailSrc ? [initialData.thumbnailSrc] : []),
...(initialData.additionalImages || []),
]
: [],
initialSelectedImage: initialData?.thumbnailSrc || null,
handleImagesChange,
handleSubmit: form.handleSubmit(handleFormSubmit),

View File

@@ -6,11 +6,9 @@ import { emptyModalState } from "./helpers";
import {
useGetV2GetMyAgents,
useGetV2ListMySubmissions,
getGetV2ListMySubmissionsQueryKey,
} from "@/app/api/__generated__/endpoints/store/store";
import { okData } from "@/app/api/helpers";
import type { MyAgent } from "@/app/api/__generated__/models/myAgent";
import { useQueryClient } from "@tanstack/react-query";
const defaultTargetState: PublishState = {
isOpen: false,
@@ -67,7 +65,6 @@ export function usePublishAgentModal({
>(preSelectedAgentVersion || null);
const router = useRouter();
const queryClient = useQueryClient();
// Fetch agent data for pre-populating form when agent is pre-selected
const { data: myAgents } = useGetV2GetMyAgents();
@@ -80,18 +77,14 @@ export function usePublishAgentModal({
}
}, [targetState]);
// Reset internal state when modal opens (only on initial open, not on every targetState change)
const [hasOpened, setHasOpened] = useState(false);
// Reset internal state when modal opens
useEffect(() => {
if (!targetState) return;
if (targetState.isOpen && !hasOpened) {
if (targetState.isOpen) {
setSelectedAgent(null);
setSelectedAgentId(preSelectedAgentId || null);
setSelectedAgentVersion(preSelectedAgentVersion || null);
setInitialData(emptyModalState);
setHasOpened(true);
} else if (!targetState.isOpen && hasOpened) {
setHasOpened(false);
}
}, [targetState, preSelectedAgentId, preSelectedAgentVersion]);
@@ -179,11 +172,6 @@ export function usePublishAgentModal({
setSelectedAgentVersion(null);
setInitialData(emptyModalState);
// Invalidate submissions query to refresh the data after modal closes
queryClient.invalidateQueries({
queryKey: getGetV2ListMySubmissionsQueryKey(),
});
// Update parent with clean closed state
const newState = {
isOpen: false,

View File

@@ -1,116 +0,0 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
import { Table } from "./Table";
const meta = {
title: "Molecules/Table",
component: Table,
decorators: [
(Story) => (
<TooltipProvider>
<Story />
</TooltipProvider>
),
],
parameters: {
layout: "centered",
},
tags: ["autodocs"],
argTypes: {
allowAddRow: {
control: "boolean",
description: "Whether to show the Add row button",
},
allowDeleteRow: {
control: "boolean",
description: "Whether to show delete buttons for each row",
},
readOnly: {
control: "boolean",
description:
"Whether the table is read-only (renders text instead of inputs)",
},
addRowLabel: {
control: "text",
description: "Label for the Add row button",
},
},
} satisfies Meta<typeof Table>;
export default meta;
type Story = StoryObj<typeof meta>;
export const Default: Story = {
args: {
columns: ["name", "email", "role"],
allowAddRow: true,
allowDeleteRow: true,
},
};
export const WithDefaultValues: Story = {
args: {
columns: ["name", "email", "role"],
defaultValues: [
{ name: "John Doe", email: "john@example.com", role: "Admin" },
{ name: "Jane Smith", email: "jane@example.com", role: "User" },
{ name: "Bob Wilson", email: "bob@example.com", role: "Editor" },
],
allowAddRow: true,
allowDeleteRow: true,
},
};
export const ReadOnly: Story = {
args: {
columns: ["name", "email"],
defaultValues: [
{ name: "John Doe", email: "john@example.com" },
{ name: "Jane Smith", email: "jane@example.com" },
],
readOnly: true,
},
};
export const NoAddOrDelete: Story = {
args: {
columns: ["name", "email"],
defaultValues: [
{ name: "John Doe", email: "john@example.com" },
{ name: "Jane Smith", email: "jane@example.com" },
],
allowAddRow: false,
allowDeleteRow: false,
},
};
export const SingleColumn: Story = {
args: {
columns: ["item"],
allowAddRow: true,
allowDeleteRow: true,
addRowLabel: "Add item",
},
};
export const CustomAddLabel: Story = {
args: {
columns: ["key", "value"],
allowAddRow: true,
allowDeleteRow: true,
addRowLabel: "Add new entry",
},
};
export const KeyValuePairs: Story = {
args: {
columns: ["key", "value"],
defaultValues: [
{ key: "API_KEY", value: "sk-..." },
{ key: "DATABASE_URL", value: "postgres://..." },
],
allowAddRow: true,
allowDeleteRow: true,
addRowLabel: "Add variable",
},
};

View File

@@ -1,133 +0,0 @@
import * as React from "react";
import {
Table as BaseTable,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/__legacy__/ui/table";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Text } from "@/components/atoms/Text/Text";
import { Plus, Trash2 } from "lucide-react";
import { cn } from "@/lib/utils";
import { useTable, RowData } from "./useTable";
import { formatColumnTitle, formatPlaceholder } from "./helpers";
export interface TableProps {
columns: string[];
defaultValues?: RowData[];
onChange?: (rows: RowData[]) => void;
allowAddRow?: boolean;
allowDeleteRow?: boolean;
addRowLabel?: string;
className?: string;
readOnly?: boolean;
}
export function Table({
columns,
defaultValues,
onChange,
allowAddRow = true,
allowDeleteRow = true,
addRowLabel = "Add row",
className,
readOnly = false,
}: TableProps) {
const { rows, handleAddRow, handleDeleteRow, handleCellChange } = useTable({
columns,
defaultValues,
onChange,
});
const showDeleteColumn = allowDeleteRow && !readOnly;
const showAddButton = allowAddRow && !readOnly;
return (
<div className={cn("flex flex-col gap-3", className)}>
<div className="overflow-hidden rounded-xl border border-zinc-200 bg-white">
<BaseTable>
<TableHeader>
<TableRow className="border-b border-zinc-100 bg-zinc-50/50">
{columns.map((column) => (
<TableHead
key={column}
className="h-10 px-3 text-sm font-medium text-zinc-600"
>
{formatColumnTitle(column)}
</TableHead>
))}
{showDeleteColumn && <TableHead className="w-[50px]" />}
</TableRow>
</TableHeader>
<TableBody>
{rows.map((row, rowIndex) => (
<TableRow key={rowIndex} className="border-none">
{columns.map((column) => (
<TableCell key={`${rowIndex}-${column}`} className="p-2">
{readOnly ? (
<Text
variant="body"
className="px-3 py-2 text-sm text-zinc-800"
>
{row[column] || "-"}
</Text>
) : (
<Input
id={`table-${rowIndex}-${column}`}
label={formatColumnTitle(column)}
hideLabel
value={row[column] ?? ""}
onChange={(e) =>
handleCellChange(rowIndex, column, e.target.value)
}
placeholder={formatPlaceholder(column)}
size="small"
wrapperClassName="mb-0"
/>
)}
</TableCell>
))}
{showDeleteColumn && (
<TableCell className="p-2">
<Button
variant="icon"
size="icon"
onClick={() => handleDeleteRow(rowIndex)}
aria-label="Delete row"
className="text-zinc-400 transition-colors hover:text-red-500"
>
<Trash2 className="h-4 w-4" />
</Button>
</TableCell>
)}
</TableRow>
))}
{showAddButton && (
<TableRow>
<TableCell
colSpan={columns.length + (showDeleteColumn ? 1 : 0)}
className="p-2"
>
<Button
variant="outline"
size="small"
onClick={handleAddRow}
leftIcon={<Plus className="h-4 w-4" />}
className="w-fit"
>
{addRowLabel}
</Button>
</TableCell>
</TableRow>
)}
</TableBody>
</BaseTable>
</div>
</div>
);
}
export { type RowData } from "./useTable";

View File

@@ -1,7 +0,0 @@
export const formatColumnTitle = (key: string): string => {
return key.charAt(0).toUpperCase() + key.slice(1);
};
export const formatPlaceholder = (key: string): string => {
return `Enter ${key.toLowerCase()}`;
};

View File

@@ -1,81 +0,0 @@
import { useState, useEffect } from "react";
export type RowData = Record<string, string>;
interface UseTableOptions {
columns: string[];
defaultValues?: RowData[];
onChange?: (rows: RowData[]) => void;
}
export function useTable({
columns,
defaultValues,
onChange,
}: UseTableOptions) {
const createEmptyRow = (): RowData => {
const emptyRow: RowData = {};
columns.forEach((column) => {
emptyRow[column] = "";
});
return emptyRow;
};
const [rows, setRows] = useState<RowData[]>(() => {
if (defaultValues && defaultValues.length > 0) {
return defaultValues;
}
return [];
});
useEffect(() => {
if (defaultValues !== undefined) {
setRows(defaultValues);
}
}, [defaultValues]);
const updateRows = (newRows: RowData[]) => {
setRows(newRows);
onChange?.(newRows);
};
const handleAddRow = () => {
const newRows = [...rows, createEmptyRow()];
updateRows(newRows);
};
const handleDeleteRow = (rowIndex: number) => {
const newRows = rows.filter((_, index) => index !== rowIndex);
updateRows(newRows);
};
const handleCellChange = (
rowIndex: number,
columnKey: string,
value: string,
) => {
const newRows = rows.map((row, index) => {
if (index === rowIndex) {
return {
...row,
[columnKey]: value,
};
}
return row;
});
updateRows(newRows);
};
const clearAll = () => {
updateRows([]);
};
return {
rows,
handleAddRow,
handleDeleteRow,
handleCellChange,
clearAll,
createEmptyRow,
};
}

View File

@@ -37,3 +37,11 @@ html body .toastDescription {
font-size: 0.75rem !important;
line-height: 1.25rem !important;
}
/* Position close button on the right */
/* stylelint-disable-next-line selector-pseudo-class-no-unknown */
#root [data-sonner-toast] [data-close-button="true"] {
left: unset !important;
right: -18px !important;
top: -3px !important;
}

View File

@@ -1,5 +1,6 @@
"use client";
import { CheckCircle, Info, Warning, XCircle } from "@phosphor-icons/react";
import { Toaster as SonnerToaster } from "sonner";
import styles from "./styles.module.css";
@@ -22,10 +23,10 @@ export function Toaster() {
}}
className="custom__toast"
icons={{
success: null,
error: null,
warning: null,
info: null,
success: <CheckCircle className="h-4 w-4" color="#fff" weight="fill" />,
error: <XCircle className="h-4 w-4" color="#fff" weight="fill" />,
warning: <Warning className="h-4 w-4" color="#fff" weight="fill" />,
info: <Info className="h-4 w-4" color="#fff" weight="fill" />,
}}
/>
);

View File

@@ -30,8 +30,6 @@ export const FormRenderer = ({
return generateUiSchemaForCustomFields(preprocessedSchema, uiSchema);
}, [preprocessedSchema, uiSchema]);
console.log("preprocessedSchema", preprocessedSchema);
return (
<div className={"mb-6 mt-4"}>
<Form

View File

@@ -5,14 +5,19 @@ import { useAnyOfField } from "./useAnyOfField";
import { getHandleId, updateUiOption } from "../../helpers";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { ANY_OF_FLAG } from "../../constants";
import { findCustomFieldId } from "../../registry";
export const AnyOfField = (props: FieldProps) => {
const { registry, schema } = props;
const { fields } = registry;
const { SchemaField: _SchemaField } = fields;
const { nodeId } = registry.formContext;
const { isInputConnected } = useEdgeStore();
const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions);
const Widget = getWidget({ type: "string" }, "select", registry.widgets);
const {
handleOptionChange,
enumOptions,
@@ -21,15 +26,6 @@ export const AnyOfField = (props: FieldProps) => {
field_id,
} = useAnyOfField(props);
const parentCustomFieldId = findCustomFieldId(schema);
if (parentCustomFieldId) {
return null;
}
const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions);
const Widget = getWidget({ type: "string" }, "select", registry.widgets);
const handleId = getHandleId({
uiOptions,
id: field_id + ANY_OF_FLAG,
@@ -44,21 +40,12 @@ export const AnyOfField = (props: FieldProps) => {
const isHandleConnected = isInputConnected(nodeId, handleId);
// Now anyOf can render - custom fields if the option schema matches a custom field
const optionCustomFieldId = optionSchema
? findCustomFieldId(optionSchema)
: null;
const optionUiSchema = optionCustomFieldId
? { ...updatedUiSchema, "ui:field": optionCustomFieldId }
: updatedUiSchema;
const optionsSchemaField =
(optionSchema && optionSchema.type !== "null" && (
<_SchemaField
{...props}
schema={optionSchema}
uiSchema={optionUiSchema}
uiSchema={updatedUiSchema}
/>
)) ||
null;

View File

@@ -17,7 +17,6 @@ interface InputExpanderModalProps {
defaultValue: string;
description?: string;
placeholder?: string;
inputType?: "text" | "json";
}
export const InputExpanderModal: FC<InputExpanderModalProps> = ({
@@ -28,7 +27,6 @@ export const InputExpanderModal: FC<InputExpanderModalProps> = ({
defaultValue,
description,
placeholder,
inputType = "text",
}) => {
const [tempValue, setTempValue] = useState(defaultValue);
const [isCopied, setIsCopied] = useState(false);
@@ -80,10 +78,7 @@ export const InputExpanderModal: FC<InputExpanderModalProps> = ({
hideLabel
id="input-expander-modal"
value={tempValue}
className={cn(
"!min-h-[300px] rounded-2xlarge",
inputType === "json" && "font-mono text-sm",
)}
className="!min-h-[300px] rounded-2xlarge"
onChange={(e) => setTempValue(e.target.value)}
placeholder={placeholder || "Enter text..."}
autoFocus

View File

@@ -8,34 +8,19 @@ import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/component
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { useShallow } from "zustand/react/shallow";
import { CredentialFieldTitle } from "./components/CredentialFieldTitle";
import { Switch } from "@/components/atoms/Switch/Switch";
export const CredentialsField = (props: FieldProps) => {
const { formData, onChange, schema, registry, fieldPathId, required } = props;
const { formData, onChange, schema, registry, fieldPathId } = props;
const formContext = registry.formContext;
const uiOptions = getUiOptions(props.uiSchema);
const nodeId = formContext?.nodeId;
// Get sibling inputs (hardcoded values) and credentials optional state from the node store
// Note: We select the node data directly instead of using getter functions to avoid
// creating new object references that would cause infinite re-render loops with useShallow
const { node, setCredentialsOptional } = useNodeStore(
useShallow((state) => ({
node: nodeId ? state.nodes.find((n) => n.id === nodeId) : undefined,
setCredentialsOptional: state.setCredentialsOptional,
})),
// Get sibling inputs (hardcoded values) from the node store
const hardcodedValues = useNodeStore(
useShallow((state) => (nodeId ? state.getHardCodedValues(nodeId) : {})),
);
const hardcodedValues = useMemo(
() => node?.data?.hardcodedValues || {},
[node?.data?.hardcodedValues],
);
const credentialsOptional = useMemo(() => {
const value = node?.data?.metadata?.credentials_optional;
return typeof value === "boolean" ? value : false;
}, [node?.data?.metadata?.credentials_optional]);
const handleChange = (newValue: any) => {
onChange(newValue, fieldPathId?.path);
};
@@ -67,10 +52,6 @@ export const CredentialsField = (props: FieldProps) => {
[formData?.id, formData?.provider, formData?.title, formData?.type],
);
// In builder canvas (nodeId exists): show star based on credentialsOptional toggle
// In run dialogs (no nodeId): show star based on schema's required array
const isRequired = nodeId ? !credentialsOptional : required;
return (
<div className="flex flex-col gap-2">
<CredentialFieldTitle
@@ -78,7 +59,6 @@ export const CredentialsField = (props: FieldProps) => {
registry={registry}
uiOptions={uiOptions}
schema={schema}
required={isRequired}
/>
<CredentialsInput
schema={schema as BlockIOCredentialsSubSchema}
@@ -87,31 +67,7 @@ export const CredentialsField = (props: FieldProps) => {
siblingInputs={hardcodedValues}
showTitle={false}
readOnly={formContext?.readOnly}
isOptional={!isRequired}
className="w-full"
variant="node"
/>
{/* Optional credentials toggle - only show in builder canvas, not run dialogs */}
{nodeId &&
!formContext?.readOnly &&
formContext?.showOptionalToggle !== false && (
<div className="mt-1 flex items-center gap-2">
<Switch
id={`credentials-optional-${nodeId}`}
checked={credentialsOptional}
onCheckedChange={(checked) =>
setCredentialsOptional(nodeId, checked)
}
/>
<label
htmlFor={`credentials-optional-${nodeId}`}
className="cursor-pointer text-xs text-gray-500"
>
Optional - skip block if not configured
</label>
</div>
)}
</div>
);
};

View File

@@ -18,9 +18,8 @@ export const CredentialFieldTitle = (props: {
uiOptions: UiSchema;
schema: RJSFSchema;
fieldPathId: FieldPathId;
required?: boolean;
}) => {
const { registry, uiOptions, schema, fieldPathId, required = false } = props;
const { registry, uiOptions, schema, fieldPathId } = props;
const { nodeId } = registry.formContext;
const TitleFieldTemplate = getTemplate(
@@ -51,7 +50,7 @@ export const CredentialFieldTitle = (props: {
<TitleFieldTemplate
id={titleId(fieldPathId ?? "")}
title={credentialProvider ?? ""}
required={required}
required={true}
schema={schema}
registry={registry}
uiSchema={updatedUiSchema}

View File

@@ -1,124 +0,0 @@
"use client";
import { FieldProps, getTemplate, getUiOptions } from "@rjsf/utils";
import { Input } from "@/components/atoms/Input/Input";
import { Button } from "@/components/atoms/Button/Button";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { ArrowsOutIcon } from "@phosphor-icons/react";
import { InputExpanderModal } from "../../base/standard/widgets/TextInput/TextInputExpanderModal";
import { getHandleId, updateUiOption } from "../../helpers";
import { useJsonTextField } from "./useJsonTextField";
import { getPlaceholder } from "./helpers";
export const JsonTextField = (props: FieldProps) => {
const {
formData,
onChange,
schema,
registry,
uiSchema,
required,
name,
fieldPathId,
} = props;
const uiOptions = getUiOptions(uiSchema);
const TitleFieldTemplate = getTemplate(
"TitleFieldTemplate",
registry,
uiOptions,
);
const fieldId = fieldPathId?.$id ?? props.id ?? "json-field";
const handleId = getHandleId({
uiOptions,
id: fieldId,
schema: schema,
});
const updatedUiSchema = updateUiOption(uiSchema, {
handleId: handleId,
});
const {
textValue,
isModalOpen,
handleChange,
handleModalOpen,
handleModalClose,
handleModalSave,
} = useJsonTextField({
formData,
onChange,
path: fieldPathId?.path,
});
const placeholder = getPlaceholder(schema);
const title = schema.title || name || "JSON Value";
return (
<div className="flex flex-col gap-2">
<TitleFieldTemplate
id={fieldId}
title={title}
required={required}
schema={schema}
uiSchema={updatedUiSchema}
registry={registry}
/>
<div className="nodrag relative flex items-center gap-2">
<Input
id={fieldId}
hideLabel={true}
type="textarea"
label=""
size="small"
wrapperClassName="mb-0 flex-1 "
value={textValue}
onChange={handleChange}
placeholder={placeholder}
required={required}
disabled={props.disabled}
className="min-h-[60px] pr-8 font-mono text-xs"
/>
<Tooltip delayDuration={0}>
<TooltipTrigger asChild>
<Button
variant="ghost"
size="icon"
onClick={handleModalOpen}
type="button"
className="p-1"
>
<ArrowsOutIcon className="size-4" />
</Button>
</TooltipTrigger>
<TooltipContent>Expand input</TooltipContent>
</Tooltip>
</div>
{schema.description && (
<span className="text-xs text-gray-500">{schema.description}</span>
)}
<InputExpanderModal
isOpen={isModalOpen}
onClose={handleModalClose}
onSave={handleModalSave}
title={`Edit ${title}`}
description={schema.description || "Enter valid JSON"}
defaultValue={textValue}
placeholder={placeholder}
inputType="json"
/>
</div>
);
};
export default JsonTextField;

Some files were not shown because too many files have changed in this diff Show More