mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
22 Commits
master
...
feat/llm-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
957ec038b8 | ||
|
|
9b39a662ee | ||
|
|
9b93a956b4 | ||
|
|
b236719bbf | ||
|
|
4f286f510f | ||
|
|
b1595d871d | ||
|
|
29ab7f2d9c | ||
|
|
784936b323 | ||
|
|
f2ae38a1a7 | ||
|
|
2ccfb4e4c1 | ||
|
|
c65e5c957a | ||
|
|
54355a691b | ||
|
|
3cafa49c4c | ||
|
|
ded002a406 | ||
|
|
4fdf89c3be | ||
|
|
d816bd739f | ||
|
|
6a16376323 | ||
|
|
ed7b02ffb1 | ||
|
|
d064198dd1 | ||
|
|
01ad033b2b | ||
|
|
56bcbda054 | ||
|
|
d40efc6056 |
@@ -53,7 +53,6 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -128,7 +127,6 @@ class SessionSummaryResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
is_processing: bool
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
@@ -187,28 +185,6 @@ async def list_sessions(
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
# Batch-check Redis for active stream status on each session
|
||||
processing_set: set[str] = set()
|
||||
if sessions:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
for session in sessions:
|
||||
pipe.hget(
|
||||
f"{config.session_meta_prefix}{session.session_id}",
|
||||
"status",
|
||||
)
|
||||
statuses = await pipe.execute()
|
||||
processing_set = {
|
||||
session.session_id
|
||||
for session, st in zip(sessions, statuses)
|
||||
if st == "running"
|
||||
}
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||
)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
@@ -216,7 +192,6 @@ async def list_sessions(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
is_processing=session.session_id in processing_set,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
|
||||
@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
|
||||
|
||||
# Mock get_node_executions to return node_id mapping
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
"backend.data.execution.get_node_executions"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_exec_id = "test_node_123"
|
||||
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
|
||||
|
||||
# Mock get_node_executions to return node_id mapping
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
"backend.data.execution.get_node_executions"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_exec_id = "node_exec_approved"
|
||||
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
|
||||
|
||||
# Mock get_node_executions to return batch node data
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
"backend.data.execution.get_node_executions"
|
||||
)
|
||||
# Create mock node executions for each review
|
||||
mock_node_execs = []
|
||||
|
||||
@@ -6,15 +6,10 @@ import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.copilot.constants import (
|
||||
is_copilot_synthetic_id,
|
||||
parse_node_id_from_exec_id,
|
||||
)
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
get_graph_execution_meta,
|
||||
get_node_executions,
|
||||
)
|
||||
from backend.data.graph import get_graph_settings
|
||||
from backend.data.human_review import (
|
||||
@@ -41,38 +36,6 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_node_ids(
|
||||
node_exec_ids: list[str],
|
||||
graph_exec_id: str,
|
||||
is_copilot: bool,
|
||||
) -> dict[str, str]:
|
||||
"""Resolve node_exec_id -> node_id for auto-approval records.
|
||||
|
||||
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
|
||||
Graph executions look up node_id from NodeExecution records.
|
||||
"""
|
||||
if not node_exec_ids:
|
||||
return {}
|
||||
|
||||
if is_copilot:
|
||||
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
|
||||
|
||||
result = {}
|
||||
for neid in node_exec_ids:
|
||||
if neid in node_exec_map:
|
||||
result[neid] = node_exec_map[neid]
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to resolve node_id for {neid}: Node execution not found."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/pending",
|
||||
summary="Get Pending Reviews",
|
||||
@@ -147,16 +110,14 @@ async def list_pending_reviews_for_execution(
|
||||
"""
|
||||
|
||||
# Verify user owns the graph execution before returning reviews
|
||||
# (CoPilot synthetic IDs don't have graph execution records)
|
||||
if not is_copilot_synthetic_id(graph_exec_id):
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||
|
||||
@@ -199,26 +160,30 @@ async def process_review_action(
|
||||
)
|
||||
|
||||
graph_exec_id = next(iter(graph_exec_ids))
|
||||
is_copilot = is_copilot_synthetic_id(graph_exec_id)
|
||||
|
||||
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
|
||||
if not is_copilot:
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
# Validate execution status before processing reviews
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
# Only allow processing reviews if execution is paused for review
|
||||
# or incomplete (partial execution with some reviews already processed)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||
f"Current status: {graph_exec_meta.status}",
|
||||
)
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
|
||||
)
|
||||
|
||||
# Build review decisions map and track which reviews requested auto-approval
|
||||
# Auto-approved reviews use original data (no modifications allowed)
|
||||
@@ -271,7 +236,7 @@ async def process_review_action(
|
||||
)
|
||||
return (node_id, False)
|
||||
|
||||
# Collect node_exec_ids that need auto-approval and resolve their node_ids
|
||||
# Collect node_exec_ids that need auto-approval
|
||||
node_exec_ids_needing_auto_approval = [
|
||||
node_exec_id
|
||||
for node_exec_id, review_result in updated_reviews.items()
|
||||
@@ -279,16 +244,29 @@ async def process_review_action(
|
||||
and auto_approve_requests.get(node_exec_id, False)
|
||||
]
|
||||
|
||||
node_id_map = await _resolve_node_ids(
|
||||
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
|
||||
)
|
||||
|
||||
# Deduplicate by node_id — one auto-approval per node
|
||||
# Batch-fetch node executions to get node_ids
|
||||
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_id = node_id_map.get(node_exec_id)
|
||||
if node_id and node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
|
||||
if node_exec_ids_needing_auto_approval:
|
||||
from backend.data.execution import get_node_executions
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
||||
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_exec = node_exec_map.get(node_exec_id)
|
||||
if node_exec:
|
||||
review_result = updated_reviews[node_exec_id]
|
||||
# Use the first approved review for this node (deduplicate by node_id)
|
||||
if node_exec.node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for {node_exec_id}: "
|
||||
f"Node execution not found. This may indicate a race condition "
|
||||
f"or data inconsistency."
|
||||
)
|
||||
|
||||
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||
auto_approval_results = await asyncio.gather(
|
||||
@@ -303,11 +281,13 @@ async def process_review_action(
|
||||
auto_approval_failed_count = 0
|
||||
for result in auto_approval_results:
|
||||
if isinstance(result, Exception):
|
||||
# Unexpected exception during auto-approval creation
|
||||
auto_approval_failed_count += 1
|
||||
logger.error(
|
||||
f"Unexpected exception during auto-approval creation: {result}"
|
||||
)
|
||||
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||
# Auto-approval creation failed (returned False)
|
||||
auto_approval_failed_count += 1
|
||||
|
||||
# Count results
|
||||
@@ -322,20 +302,22 @@ async def process_review_action(
|
||||
if review.status == ReviewStatus.REJECTED
|
||||
)
|
||||
|
||||
# Resume graph execution only for real graph executions (not CoPilot)
|
||||
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
|
||||
if not is_copilot and updated_reviews:
|
||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
||||
if updated_reviews:
|
||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||
|
||||
if not still_has_pending:
|
||||
# Get the graph_id from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
|
||||
try:
|
||||
# Fetch user and settings to build complete execution context
|
||||
user = await get_user_by_id(user_id)
|
||||
settings = await get_graph_settings(
|
||||
user_id=user_id, graph_id=first_review.graph_id
|
||||
)
|
||||
|
||||
# Preserve user's timezone preference when resuming execution
|
||||
user_timezone = (
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
@@ -165,6 +165,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -205,9 +206,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
default_factory=list,
|
||||
description="List of recent executions with status, score, and summary",
|
||||
)
|
||||
can_access_graph: bool = pydantic.Field(
|
||||
description="Indicates whether the same user owns the corresponding graph"
|
||||
)
|
||||
can_access_graph: bool
|
||||
is_latest_version: bool
|
||||
is_favorite: bool
|
||||
folder_id: str | None = None
|
||||
@@ -325,6 +324,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
|
||||
@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -129,6 +131,7 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -181,6 +184,7 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
|
||||
@@ -94,8 +94,3 @@ class NotificationPayload(pydantic.BaseModel):
|
||||
|
||||
class OnboardingNotificationPayload(NotificationPayload):
|
||||
step: OnboardingStep | None
|
||||
|
||||
|
||||
class CopilotCompletionPayload(NotificationPayload):
|
||||
session_id: str
|
||||
status: Literal["completed", "failed"]
|
||||
|
||||
@@ -37,8 +37,10 @@ import backend.api.features.workspace.routes as workspace_routes
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.llm_registry
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.v2.llm
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.library.exceptions import (
|
||||
@@ -117,11 +119,30 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
|
||||
# When block integration lands, this should fail hard or skip block initialization
|
||||
try:
|
||||
await backend.data.llm_registry.refresh_llm_registry()
|
||||
logger.info("LLM registry refreshed successfully at startup")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to refresh LLM registry at startup: {e}. "
|
||||
"Blocks will initialize with empty registry."
|
||||
)
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
try:
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to migrate LLM models at startup: {e}. "
|
||||
"This is expected in test environments without AgentNode table."
|
||||
)
|
||||
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
@@ -348,6 +369,11 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.llm.router,
|
||||
tags=["v2", "llm"],
|
||||
prefix="/api",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -624,7 +624,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: "ExecutionContext",
|
||||
is_graph_execution: bool = True,
|
||||
**kwargs,
|
||||
) -> tuple[bool, BlockInput]:
|
||||
"""
|
||||
@@ -653,7 +652,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
graph_version=graph_version,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
is_graph_execution=is_graph_execution,
|
||||
)
|
||||
|
||||
if decision is None:
|
||||
|
||||
@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
is_sensitive_action=True,
|
||||
disabled=True,
|
||||
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
|
||||
@@ -96,7 +96,6 @@ class SendEmailBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Email sent successfully")],
|
||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
def github_repo_path(repo_url: str) -> str:
|
||||
"""Extract 'owner/repo' from a GitHub repository URL."""
|
||||
return repo_url.replace("https://github.com/", "")
|
||||
@@ -1,374 +0,0 @@
|
||||
import asyncio
|
||||
from enum import StrEnum
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListCommitsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to list commits from",
|
||||
default="main",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of commits to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class CommitItem(TypedDict):
|
||||
sha: str
|
||||
message: str
|
||||
author: str
|
||||
date: str
|
||||
url: str
|
||||
|
||||
commit: CommitItem = SchemaField(
|
||||
title="Commit", description="A commit with its details"
|
||||
)
|
||||
commits: list[CommitItem] = SchemaField(
|
||||
description="List of commits with their details"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing commits failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8b13f579-d8b6-4dc2-a140-f770428805de",
|
||||
description="This block lists commits on a branch in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListCommitsBlock.Input,
|
||||
output_schema=GithubListCommitsBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"commits",
|
||||
[
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"commit",
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_commits": lambda *args, **kwargs: [
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_commits(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
per_page: int,
|
||||
page: int,
|
||||
) -> list[Output.CommitItem]:
|
||||
api = get_api(credentials)
|
||||
commits_url = repo_url + "/commits"
|
||||
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
|
||||
response = await api.get(commits_url, params=params)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
return [
|
||||
GithubListCommitsBlock.Output.CommitItem(
|
||||
sha=c["sha"],
|
||||
message=c["commit"]["message"],
|
||||
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
|
||||
date=(c["commit"].get("author") or {}).get("date", ""),
|
||||
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
|
||||
)
|
||||
for c in data
|
||||
]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
commits = await self.list_commits(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "commits", commits
|
||||
for commit in commits:
|
||||
yield "commit", commit
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class FileOperation(StrEnum):
|
||||
"""File operations for GithubMultiFileCommitBlock.
|
||||
|
||||
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
|
||||
between creation and update — the blob is placed at the given path regardless
|
||||
of whether a file already exists there).
|
||||
|
||||
DELETE removes a file from the tree.
|
||||
"""
|
||||
|
||||
UPSERT = "upsert"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
content: str
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
class GithubMultiFileCommitBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to commit to",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Commit message",
|
||||
placeholder="Add new feature",
|
||||
)
|
||||
files: list[FileOperationInput] = SchemaField(
|
||||
description=(
|
||||
"List of file operations. Each item has: "
|
||||
"'path' (file path), 'content' (file content, ignored for delete), "
|
||||
"'operation' (upsert/delete)"
|
||||
),
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the new commit")
|
||||
url: str = SchemaField(description="URL of the new commit")
|
||||
error: str = SchemaField(description="Error message if the commit failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="389eee51-a95e-4230-9bed-92167a327802",
|
||||
description=(
|
||||
"This block creates a single commit with multiple file "
|
||||
"upsert/delete operations using the Git Trees API."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMultiFileCommitBlock.Input,
|
||||
output_schema=GithubMultiFileCommitBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "Add files",
|
||||
"files": [
|
||||
{
|
||||
"path": "src/new.py",
|
||||
"content": "print('hello')",
|
||||
"operation": "upsert",
|
||||
},
|
||||
{
|
||||
"path": "src/old.py",
|
||||
"content": "",
|
||||
"operation": "delete",
|
||||
},
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "newcommitsha"),
|
||||
("url", "https://github.com/owner/repo/commit/newcommitsha"),
|
||||
],
|
||||
test_mock={
|
||||
"multi_file_commit": lambda *args, **kwargs: (
|
||||
"newcommitsha",
|
||||
"https://github.com/owner/repo/commit/newcommitsha",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def multi_file_commit(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
files: list[FileOperationInput],
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
safe_branch = quote(branch, safe="")
|
||||
|
||||
# 1. Get the latest commit SHA for the branch
|
||||
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
|
||||
response = await api.get(ref_url)
|
||||
ref_data = response.json()
|
||||
latest_commit_sha = ref_data["object"]["sha"]
|
||||
|
||||
# 2. Get the tree SHA of the latest commit
|
||||
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
|
||||
response = await api.get(commit_url)
|
||||
commit_data = response.json()
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str) -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": "utf-8"},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
tree_entries: list[dict] = []
|
||||
upsert_files = []
|
||||
for file_op in files:
|
||||
path = file_op["path"]
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
|
||||
if operation == FileOperation.DELETE:
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": None, # null SHA = delete
|
||||
}
|
||||
)
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently
|
||||
if upsert_files:
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_create_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": blob_sha,
|
||||
}
|
||||
)
|
||||
|
||||
# 4. Create a new tree
|
||||
tree_url = repo_url + "/git/trees"
|
||||
tree_response = await api.post(
|
||||
tree_url,
|
||||
json={"base_tree": base_tree_sha, "tree": tree_entries},
|
||||
)
|
||||
new_tree_sha = tree_response.json()["sha"]
|
||||
|
||||
# 5. Create a new commit
|
||||
new_commit_url = repo_url + "/git/commits"
|
||||
commit_response = await api.post(
|
||||
new_commit_url,
|
||||
json={
|
||||
"message": commit_message,
|
||||
"tree": new_tree_sha,
|
||||
"parents": [latest_commit_sha],
|
||||
},
|
||||
)
|
||||
new_commit_sha = commit_response.json()["sha"]
|
||||
|
||||
# 6. Update the branch reference
|
||||
try:
|
||||
await api.patch(
|
||||
ref_url,
|
||||
json={"sha": new_commit_sha},
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Commit {new_commit_sha} was created but failed to update "
|
||||
f"ref heads/{branch}: {e}. "
|
||||
f"You can recover by manually updating the branch to {new_commit_sha}."
|
||||
) from e
|
||||
|
||||
repo_path = github_repo_path(repo_url)
|
||||
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
|
||||
return new_commit_sha, commit_web_url
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
input_data.files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -21,8 +20,6 @@ from ._auth import (
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
MergeMethod = Literal["merge", "squash", "rebase"]
|
||||
|
||||
|
||||
class GithubListPullRequestsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
|
||||
yield "reviewer", reviewer
|
||||
|
||||
|
||||
class GithubMergePullRequestBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
merge_method: MergeMethod = SchemaField(
|
||||
description="Merge method to use: merge, squash, or rebase",
|
||||
default="merge",
|
||||
)
|
||||
commit_title: str = SchemaField(
|
||||
description="Title for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the merge commit")
|
||||
merged: bool = SchemaField(description="Whether the PR was merged")
|
||||
message: str = SchemaField(description="Merge status message")
|
||||
error: str = SchemaField(description="Error message if the merge failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
|
||||
description="This block merges a pull request using merge, squash, or rebase.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMergePullRequestBlock.Input,
|
||||
output_schema=GithubMergePullRequestBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "abc123"),
|
||||
("merged", True),
|
||||
("message", "Pull Request successfully merged"),
|
||||
],
|
||||
test_mock={
|
||||
"merge_pr": lambda *args, **kwargs: (
|
||||
"abc123",
|
||||
True,
|
||||
"Pull Request successfully merged",
|
||||
)
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def merge_pr(
|
||||
credentials: GithubCredentials,
|
||||
pr_url: str,
|
||||
merge_method: MergeMethod,
|
||||
commit_title: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, bool, str]:
|
||||
api = get_api(credentials)
|
||||
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
|
||||
data: dict[str, str] = {"merge_method": merge_method}
|
||||
if commit_title:
|
||||
data["commit_title"] = commit_title
|
||||
if commit_message:
|
||||
data["commit_message"] = commit_message
|
||||
response = await api.put(merge_url, json=data)
|
||||
result = response.json()
|
||||
return result["sha"], result["merged"], result["message"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sha, merged, message = await self.merge_pr(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.merge_method,
|
||||
input_data.commit_title,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "merged", merged
|
||||
yield "message", message
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
||||
# Pattern to capture the base repository URL and the pull request number
|
||||
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
match = re.match(pattern, pr_url)
|
||||
if not match:
|
||||
return pr_url
|
||||
|
||||
scheme, base_url, pr_number = match.groups()
|
||||
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
|
||||
base_url, pr_number = match.groups()
|
||||
return f"{base_url}/pulls/{pr_number}/{path}"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import base64
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -17,7 +19,6 @@ from ._auth import (
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListTagsBlock(Block):
|
||||
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
|
||||
tags_url = repo_url + "/tags"
|
||||
response = await api.get(tags_url)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
tags: list[GithubListTagsBlock.Output.TagItem] = [
|
||||
{
|
||||
"name": tag["name"],
|
||||
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
|
||||
yield "tag", tag
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(branches_url)
|
||||
data = response.json()
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
|
||||
|
||||
class GithubListDiscussionsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
|
||||
) -> list[Output.DiscussionItem]:
|
||||
api = get_api(credentials)
|
||||
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
owner, repo = repo_path.split("/")
|
||||
query = """
|
||||
query($owner: String!, $repo: String!, $num: Int!) {
|
||||
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
|
||||
yield "release", release
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to master)",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{branch}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
|
||||
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
|
||||
description="This block lists all users who have starred a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListStargazersBlock.Input,
|
||||
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
|
||||
yield "stargazers", stargazers
|
||||
for stargazer in stargazers:
|
||||
yield "stargazer", stargazer
|
||||
|
||||
|
||||
class GithubGetRepositoryInfoBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
name: str = SchemaField(description="Repository name")
|
||||
full_name: str = SchemaField(description="Full repository name (owner/repo)")
|
||||
description: str = SchemaField(description="Repository description")
|
||||
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
|
||||
private: bool = SchemaField(description="Whether the repository is private")
|
||||
html_url: str = SchemaField(description="Web URL of the repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL")
|
||||
stars: int = SchemaField(description="Number of stars")
|
||||
forks: int = SchemaField(description="Number of forks")
|
||||
open_issues: int = SchemaField(description="Number of open issues")
|
||||
error: str = SchemaField(
|
||||
description="Error message if fetching repo info failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59d4f241-968a-4040-95da-348ac5c5ce27",
|
||||
description="This block retrieves metadata about a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryInfoBlock.Input,
|
||||
output_schema=GithubGetRepositoryInfoBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("name", "repo"),
|
||||
("full_name", "owner/repo"),
|
||||
("description", "A test repo"),
|
||||
("default_branch", "main"),
|
||||
("private", False),
|
||||
("html_url", "https://github.com/owner/repo"),
|
||||
("clone_url", "https://github.com/owner/repo.git"),
|
||||
("stars", 42),
|
||||
("forks", 5),
|
||||
("open_issues", 3),
|
||||
],
|
||||
test_mock={
|
||||
"get_repo_info": lambda *args, **kwargs: {
|
||||
"name": "repo",
|
||||
"full_name": "owner/repo",
|
||||
"description": "A test repo",
|
||||
"default_branch": "main",
|
||||
"private": False,
|
||||
"html_url": "https://github.com/owner/repo",
|
||||
"clone_url": "https://github.com/owner/repo.git",
|
||||
"stargazers_count": 42,
|
||||
"forks_count": 5,
|
||||
"open_issues_count": 3,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
|
||||
api = get_api(credentials)
|
||||
response = await api.get(repo_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_repo_info(credentials, input_data.repo_url)
|
||||
yield "name", data["name"]
|
||||
yield "full_name", data["full_name"]
|
||||
yield "description", data.get("description", "") or ""
|
||||
yield "default_branch", data["default_branch"]
|
||||
yield "private", data["private"]
|
||||
yield "html_url", data["html_url"]
|
||||
yield "clone_url", data["clone_url"]
|
||||
yield "stars", data["stargazers_count"]
|
||||
yield "forks", data["forks_count"]
|
||||
yield "open_issues", data["open_issues_count"]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubForkRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to fork",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
organization: str = SchemaField(
|
||||
description="Organization to fork into (leave empty to fork to your account)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the forked repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL of the fork")
|
||||
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
|
||||
error: str = SchemaField(description="Error message if the fork failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
|
||||
description="This block forks a GitHub repository to your account or an organization.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubForkRepositoryBlock.Input,
|
||||
output_schema=GithubForkRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"organization": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/myuser/repo"),
|
||||
("clone_url", "https://github.com/myuser/repo.git"),
|
||||
("full_name", "myuser/repo"),
|
||||
],
|
||||
test_mock={
|
||||
"fork_repo": lambda *args, **kwargs: (
|
||||
"https://github.com/myuser/repo",
|
||||
"https://github.com/myuser/repo.git",
|
||||
"myuser/repo",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def fork_repo(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
organization: str,
|
||||
) -> tuple[str, str, str]:
|
||||
api = get_api(credentials)
|
||||
forks_url = repo_url + "/forks"
|
||||
data: dict[str, str] = {}
|
||||
if organization:
|
||||
data["organization"] = organization
|
||||
response = await api.post(forks_url, json=data)
|
||||
result = response.json()
|
||||
return result["html_url"], result["clone_url"], result["full_name"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, clone_url, full_name = await self.fork_repo(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.organization,
|
||||
)
|
||||
yield "url", url
|
||||
yield "clone_url", clone_url
|
||||
yield "full_name", full_name
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubStarRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to star",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the star operation")
|
||||
error: str = SchemaField(description="Error message if starring failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bd700764-53e3-44dd-a969-d1854088458f",
|
||||
description="This block stars a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubStarRepositoryBlock.Input,
|
||||
output_schema=GithubStarRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Repository starred successfully")],
|
||||
test_mock={
|
||||
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
repo_path = github_repo_path(repo_url)
|
||||
owner, repo = repo_path.split("/")
|
||||
await api.put(
|
||||
f"https://api.github.com/user/starred/{owner}/{repo}",
|
||||
headers={"Content-Length": "0"},
|
||||
)
|
||||
return "Repository starred successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.star_repo(credentials, input_data.repo_url)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -1,452 +0,0 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of branches to return per page (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing branches failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(
|
||||
branches_url, params={"per_page": str(per_page), "page": str(page)}
|
||||
)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCompareBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
base: str = SchemaField(
|
||||
description="Base branch or commit SHA",
|
||||
placeholder="main",
|
||||
)
|
||||
head: str = SchemaField(
|
||||
description="Head branch or commit SHA to compare against base",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class FileChange(TypedDict):
|
||||
filename: str
|
||||
status: str
|
||||
additions: int
|
||||
deletions: int
|
||||
patch: str
|
||||
|
||||
status: str = SchemaField(
|
||||
description="Comparison status: ahead, behind, diverged, or identical"
|
||||
)
|
||||
ahead_by: int = SchemaField(
|
||||
description="Number of commits head is ahead of base"
|
||||
)
|
||||
behind_by: int = SchemaField(
|
||||
description="Number of commits head is behind base"
|
||||
)
|
||||
total_commits: int = SchemaField(
|
||||
description="Total number of commits in the comparison"
|
||||
)
|
||||
diff: str = SchemaField(description="Unified diff of all file changes")
|
||||
file: FileChange = SchemaField(
|
||||
title="Changed File", description="A changed file with its diff"
|
||||
)
|
||||
files: list[FileChange] = SchemaField(
|
||||
description="List of changed files with their diffs"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if comparison failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e4faa8c-6086-4546-ba77-172d1d560186",
|
||||
description="This block compares two branches or commits in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCompareBranchesBlock.Input,
|
||||
output_schema=GithubCompareBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"base": "main",
|
||||
"head": "feature",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("status", "ahead"),
|
||||
("ahead_by", 2),
|
||||
("behind_by", 0),
|
||||
("total_commits", 2),
|
||||
("diff", "+++ b/file.py\n+new line"),
|
||||
(
|
||||
"files",
|
||||
[
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"compare_branches": lambda *args, **kwargs: {
|
||||
"status": "ahead",
|
||||
"ahead_by": 2,
|
||||
"behind_by": 0,
|
||||
"total_commits": 2,
|
||||
"files": [
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def compare_branches(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
base: str,
|
||||
head: str,
|
||||
) -> dict:
|
||||
api = get_api(credentials)
|
||||
safe_base = quote(base, safe="")
|
||||
safe_head = quote(head, safe="")
|
||||
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
|
||||
response = await api.get(compare_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.compare_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.base,
|
||||
input_data.head,
|
||||
)
|
||||
yield "status", data["status"]
|
||||
yield "ahead_by", data["ahead_by"]
|
||||
yield "behind_by", data["behind_by"]
|
||||
yield "total_commits", data["total_commits"]
|
||||
|
||||
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
|
||||
GithubCompareBranchesBlock.Output.FileChange(
|
||||
filename=f["filename"],
|
||||
status=f["status"],
|
||||
additions=f["additions"],
|
||||
deletions=f["deletions"],
|
||||
patch=f.get("patch", ""),
|
||||
)
|
||||
for f in data.get("files", [])
|
||||
]
|
||||
|
||||
# Build unified diff
|
||||
diff_parts = []
|
||||
for f in data.get("files", []):
|
||||
patch = f.get("patch", "")
|
||||
if patch:
|
||||
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
|
||||
yield "diff", "\n".join(diff_parts)
|
||||
|
||||
yield "files", files
|
||||
for file in files:
|
||||
yield "file", file
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,720 +0,0 @@
|
||||
import base64
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
error: str = SchemaField(description="Error message if reading the file failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to main)",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubSearchCodeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
query: str = SchemaField(
|
||||
description="Search query (GitHub code search syntax)",
|
||||
placeholder="className language:python",
|
||||
)
|
||||
repo: str = SchemaField(
|
||||
description="Restrict search to a repository (owner/repo format, optional)",
|
||||
default="",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of results to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class SearchResult(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
repository: str
|
||||
url: str
|
||||
score: float
|
||||
|
||||
result: SearchResult = SchemaField(
|
||||
title="Result", description="A code search result"
|
||||
)
|
||||
results: list[SearchResult] = SchemaField(
|
||||
description="List of code search results"
|
||||
)
|
||||
total_count: int = SchemaField(description="Total number of matching results")
|
||||
error: str = SchemaField(description="Error message if search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
|
||||
description="This block searches for code in GitHub repositories.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubSearchCodeBlock.Input,
|
||||
output_schema=GithubSearchCodeBlock.Output,
|
||||
test_input={
|
||||
"query": "addClass",
|
||||
"repo": "owner/repo",
|
||||
"per_page": 30,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("total_count", 1),
|
||||
(
|
||||
"results",
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"result",
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_code": lambda *args, **kwargs: (
|
||||
1,
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_code(
|
||||
credentials: GithubCredentials,
|
||||
query: str,
|
||||
repo: str,
|
||||
per_page: int,
|
||||
) -> tuple[int, list[Output.SearchResult]]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
full_query = f"{query} repo:{repo}" if repo else query
|
||||
params = {"q": full_query, "per_page": str(per_page)}
|
||||
response = await api.get("https://api.github.com/search/code", params=params)
|
||||
data = response.json()
|
||||
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
|
||||
GithubSearchCodeBlock.Output.SearchResult(
|
||||
name=item["name"],
|
||||
path=item["path"],
|
||||
repository=item["repository"]["full_name"],
|
||||
url=item["html_url"],
|
||||
score=item["score"],
|
||||
)
|
||||
for item in data["items"]
|
||||
]
|
||||
return data["total_count"], results
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total_count, results = await self.search_code(
|
||||
credentials,
|
||||
input_data.query,
|
||||
input_data.repo,
|
||||
input_data.per_page,
|
||||
)
|
||||
yield "total_count", total_count
|
||||
yield "results", results
|
||||
for result in results:
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubGetRepositoryTreeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to get the tree from",
|
||||
default="main",
|
||||
)
|
||||
recursive: bool = SchemaField(
|
||||
description="Whether to recursively list the entire tree",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class TreeEntry(TypedDict):
|
||||
path: str
|
||||
type: str
|
||||
size: int
|
||||
sha: str
|
||||
|
||||
entry: TreeEntry = SchemaField(
|
||||
title="Tree Entry", description="A file or directory in the tree"
|
||||
)
|
||||
entries: list[TreeEntry] = SchemaField(
|
||||
description="List of all files and directories in the tree"
|
||||
)
|
||||
truncated: bool = SchemaField(
|
||||
description="Whether the tree was truncated due to size"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if getting tree failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
|
||||
description="This block lists the entire file tree of a GitHub repository recursively.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryTreeBlock.Input,
|
||||
output_schema=GithubGetRepositoryTreeBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"recursive": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("truncated", False),
|
||||
(
|
||||
"entries",
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"entry",
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_tree": lambda *args, **kwargs: (
|
||||
False,
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_tree(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
recursive: bool,
|
||||
) -> tuple[bool, list[Output.TreeEntry]]:
|
||||
api = get_api(credentials)
|
||||
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
|
||||
params = {"recursive": "1"} if recursive else {}
|
||||
response = await api.get(tree_url, params=params)
|
||||
data = response.json()
|
||||
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
|
||||
GithubGetRepositoryTreeBlock.Output.TreeEntry(
|
||||
path=item["path"],
|
||||
type=item["type"],
|
||||
size=item.get("size", 0),
|
||||
sha=item["sha"],
|
||||
)
|
||||
for item in data["tree"]
|
||||
]
|
||||
return data.get("truncated", False), entries
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
truncated, entries = await self.get_tree(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.recursive,
|
||||
)
|
||||
yield "truncated", truncated
|
||||
yield "entries", entries
|
||||
for entry in entries:
|
||||
yield "entry", entry
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,120 +0,0 @@
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
|
||||
from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
|
||||
|
||||
class TestPreparePrApiUrl:
|
||||
def test_https_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/42/merge"
|
||||
|
||||
def test_http_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
|
||||
assert result == "http://github.com/owner/repo/pulls/1/files"
|
||||
|
||||
def test_no_scheme_defaults_to_https(self):
|
||||
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/5/merge"
|
||||
|
||||
def test_reviewers_path(self):
|
||||
result = prepare_pr_api_url(
|
||||
"https://github.com/owner/repo/pull/99", "requested_reviewers"
|
||||
)
|
||||
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
|
||||
|
||||
def test_invalid_url_returned_as_is(self):
|
||||
url = "https://example.com/not-a-pr"
|
||||
assert prepare_pr_api_url(url, "merge") == url
|
||||
|
||||
def test_empty_string(self):
|
||||
assert prepare_pr_api_url("", "merge") == ""
|
||||
|
||||
|
||||
# ── Error-path block tests ──
|
||||
# When a block's run() yields ("error", msg), _execute() converts it to a
|
||||
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
|
||||
# which returns early on empty test_output).
|
||||
|
||||
|
||||
def _mock_block(block, mocks: dict):
|
||||
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
|
||||
for name, mock_fn in mocks.items():
|
||||
original = getattr(block, name)
|
||||
if inspect.iscoroutinefunction(original):
|
||||
|
||||
async def async_mock(*args, _fn=mock_fn, **kwargs):
|
||||
return _fn(*args, **kwargs)
|
||||
|
||||
setattr(block, name, async_mock)
|
||||
else:
|
||||
setattr(block, name, mock_fn)
|
||||
|
||||
|
||||
def _raise(exc: Exception):
|
||||
"""Helper that returns a callable which raises the given exception."""
|
||||
|
||||
def _raiser(*args, **kwargs):
|
||||
raise exc
|
||||
|
||||
return _raiser
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr_error_path():
|
||||
block = GithubMergePullRequestBlock()
|
||||
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
|
||||
input_data = {
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_file_commit_error_path():
|
||||
block = GithubMultiFileCommitBlock()
|
||||
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "test",
|
||||
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
# ── FileOperation enum tests ──
|
||||
|
||||
|
||||
class TestFileOperation:
|
||||
def test_upsert_value(self):
|
||||
assert FileOperation.UPSERT == "upsert"
|
||||
|
||||
def test_delete_value(self):
|
||||
assert FileOperation.DELETE == "delete"
|
||||
|
||||
def test_invalid_value_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("create")
|
||||
|
||||
def test_invalid_value_raises_typo(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("upser")
|
||||
@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
|
||||
h.ignore_links = False
|
||||
h.ignore_images = True
|
||||
return h.handle(html_content)
|
||||
except Exception:
|
||||
# Keep extraction resilient if html2text is unavailable or fails.
|
||||
except ImportError:
|
||||
# Fallback: return raw HTML if html2text is not available
|
||||
return html_content
|
||||
|
||||
# Handle content stored as attachment
|
||||
|
||||
@@ -67,7 +67,6 @@ class HITLReviewHelper:
|
||||
graph_version: int,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
is_graph_execution: bool = True,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Handle a review request for a block that requires human review.
|
||||
@@ -144,11 +143,10 @@ class HITLReviewHelper:
|
||||
logger.info(
|
||||
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
|
||||
)
|
||||
if is_graph_execution:
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
return None # Signal that execution should pause
|
||||
|
||||
# Mark review as processed if not already done
|
||||
@@ -170,7 +168,6 @@ class HITLReviewHelper:
|
||||
graph_version: int,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
is_graph_execution: bool = True,
|
||||
) -> Optional[ReviewDecision]:
|
||||
"""
|
||||
Handle a review request and return the decision in a single call.
|
||||
@@ -200,7 +197,6 @@ class HITLReviewHelper:
|
||||
graph_version=graph_version,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
is_graph_execution=is_graph_execution,
|
||||
)
|
||||
|
||||
if review_result is None:
|
||||
|
||||
@@ -140,31 +140,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
|
||||
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
|
||||
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
|
||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
|
||||
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
|
||||
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
|
||||
CODESTRAL = "mistralai/codestral-2508"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
|
||||
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
|
||||
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
|
||||
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
|
||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
@@ -172,11 +160,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
MICROSOFT_PHI_4 = "microsoft/phi-4"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
GROK_3 = "x-ai/grok-3"
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
@@ -354,41 +340,17 @@ MODEL_METADATA = {
|
||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||
),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
1050000,
|
||||
8192,
|
||||
"Gemini 2.5 Pro Preview 03.25",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 2.5 Pro",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Pro Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3 Flash Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||
@@ -396,15 +358,6 @@ MODEL_METADATA = {
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Flash Lite Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
@@ -426,78 +379,12 @@ MODEL_METADATA = {
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||
),
|
||||
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
|
||||
"open_router",
|
||||
262144,
|
||||
None,
|
||||
"Mistral Large 3 2512",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
None,
|
||||
"Mistral Medium 3.1",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Mistral Small 3.2 24B",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.CODESTRAL: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
None,
|
||||
"Codestral 2508",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
|
||||
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Translate 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
32768,
|
||||
"Command A Reasoning 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
3,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Vision 07.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
@@ -510,15 +397,6 @@ MODEL_METADATA = {
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8000,
|
||||
"Sonar Reasoning Pro",
|
||||
"OpenRouter",
|
||||
"Perplexity",
|
||||
2,
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
@@ -564,9 +442,6 @@ MODEL_METADATA = {
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
|
||||
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||
),
|
||||
@@ -576,15 +451,6 @@ MODEL_METADATA = {
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.GROK_3: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Grok 3",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
2,
|
||||
),
|
||||
LlmModel.GROK_4: ModelMetadata(
|
||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||
),
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from pydantic import SecretStr, field_validator
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
@@ -13,7 +13,6 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -36,20 +35,6 @@ class PerplexityModel(str, Enum):
|
||||
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
|
||||
|
||||
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
|
||||
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
|
||||
if isinstance(value, PerplexityModel):
|
||||
return value
|
||||
try:
|
||||
return PerplexityModel(value)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PerplexityModel '{value}', "
|
||||
f"falling back to {PerplexityModel.SONAR.value}"
|
||||
)
|
||||
return PerplexityModel.SONAR
|
||||
|
||||
|
||||
PerplexityCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
|
||||
]
|
||||
@@ -88,25 +73,6 @@ class PerplexityBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
credentials: PerplexityCredentials = PerplexityCredentialsField()
|
||||
|
||||
@field_validator("model", mode="before")
|
||||
@classmethod
|
||||
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
|
||||
"""Fall back to SONAR if the model value is not a valid
|
||||
PerplexityModel (e.g. an OpenAI model ID set by the agent
|
||||
generator)."""
|
||||
return _sanitize_perplexity_model(v)
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
"""Sanitize the model field before JSON schema validation so that
|
||||
invalid values are replaced with the default instead of raising a
|
||||
BlockInputError."""
|
||||
model_value = data.get("model")
|
||||
if model_value is not None:
|
||||
data["model"] = _sanitize_perplexity_model(model_value).value
|
||||
return super().validate_data(data)
|
||||
|
||||
system_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="",
|
||||
|
||||
@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
|
||||
("post_id", "abc123"),
|
||||
],
|
||||
test_mock={"delete_post": lambda creds, post_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
|
||||
("comment_id", "xyz789"),
|
||||
],
|
||||
test_mock={"delete_comment": lambda creds, comment_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
|
||||
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
|
||||
"_convert_to_color": lambda *args, **kwargs: "black",
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Unit tests for PerplexityBlock model fallback behavior."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.perplexity import (
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
PerplexityBlock,
|
||||
PerplexityModel,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**overrides) -> dict:
|
||||
defaults = {
|
||||
"prompt": "test query",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return defaults
|
||||
|
||||
|
||||
class TestPerplexityModelFallback:
|
||||
"""Tests for fallback_invalid_model field_validator."""
|
||||
|
||||
def test_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_another_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_valid_model_string_is_kept(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
|
||||
assert inp.model == PerplexityModel.SONAR_PRO
|
||||
|
||||
def test_valid_enum_value_is_kept(self):
|
||||
inp = PerplexityBlock.Input(
|
||||
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
|
||||
)
|
||||
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
|
||||
|
||||
def test_default_model_when_omitted(self):
|
||||
inp = PerplexityBlock.Input(**_make_input())
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_value",
|
||||
[
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
)
|
||||
def test_all_valid_models_accepted(self, model_value: str):
|
||||
inp = PerplexityBlock.Input(**_make_input(model=model_value))
|
||||
assert inp.model.value == model_value
|
||||
|
||||
|
||||
class TestPerplexityValidateData:
|
||||
"""Tests for validate_data which runs during block execution (before
|
||||
Pydantic instantiation). Invalid models must be sanitized here so
|
||||
JSON schema validation does not reject them."""
|
||||
|
||||
def test_invalid_model_sanitized_before_schema_validation(self):
|
||||
data = _make_input(model="gpt-5.2-2025-12-11")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == PerplexityModel.SONAR.value
|
||||
|
||||
def test_valid_model_unchanged_by_validate_data(self):
|
||||
data = _make_input(model="perplexity/sonar-pro")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == "perplexity/sonar-pro"
|
||||
|
||||
def test_missing_model_uses_default(self):
|
||||
data = _make_input() # no model key
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
inp = PerplexityBlock.Input(**data)
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
@@ -6,32 +6,6 @@
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
# Used to distinguish CoPilot-generated records from real graph execution records
|
||||
# in PendingHumanReview and other tables.
|
||||
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
|
||||
|
||||
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
|
||||
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
|
||||
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
|
||||
|
||||
# Separator used in synthetic node_exec_id to encode node_id.
|
||||
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
|
||||
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
|
||||
|
||||
# Compaction notice messages shown to users.
|
||||
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
|
||||
COMPACTION_TOOL_NAME = "context_compaction"
|
||||
|
||||
|
||||
def is_copilot_synthetic_id(id_value: str) -> bool:
|
||||
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
|
||||
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
|
||||
|
||||
|
||||
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
"""Extract node_id from a synthetic node_exec_id.
|
||||
|
||||
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||
"""
|
||||
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||
|
||||
@@ -26,17 +26,3 @@ For other services, search the MCP registry at https://registry.modelcontextprot
|
||||
|
||||
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
|
||||
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
|
||||
|
||||
### Communication style
|
||||
|
||||
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
|
||||
Use plain, friendly language instead:
|
||||
|
||||
| Instead of… | Say… |
|
||||
|---|---|
|
||||
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
|
||||
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
|
||||
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
|
||||
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
|
||||
|
||||
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".
|
||||
|
||||
@@ -23,11 +23,6 @@ from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.notification_bus import (
|
||||
AsyncRedisNotificationEventBus,
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
@@ -43,7 +38,6 @@ from .response_model import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
_notification_bus = AsyncRedisNotificationEventBus()
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_sessions: dict[str, asyncio.Task] = {}
|
||||
@@ -751,29 +745,6 @@ async def mark_session_completed(
|
||||
|
||||
# Clean up local session reference if exists
|
||||
_local_sessions.pop(session_id, None)
|
||||
|
||||
# Publish copilot completion notification via WebSocket
|
||||
if meta:
|
||||
parsed = _parse_session_meta(meta, session_id)
|
||||
if parsed.user_id:
|
||||
try:
|
||||
await _notification_bus.publish(
|
||||
NotificationEvent(
|
||||
user_id=parsed.user_id,
|
||||
payload=CopilotCompletionPayload(
|
||||
type="copilot_completion",
|
||||
event="session_completed",
|
||||
session_id=session_id,
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to publish copilot completion notification "
|
||||
f"for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
@@ -69,7 +68,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"move_agents_to_folder": MoveAgentsToFolderTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"continue_run_block": ContinueRunBlockTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
"get_mcp_guide": GetMCPGuideTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
|
||||
@@ -829,12 +829,8 @@ class AgentFixer:
|
||||
|
||||
For nodes whose block has category "AI", this function ensures that the
|
||||
input_default has a "model" parameter set to one of the allowed models.
|
||||
If missing or set to an unsupported value, it is replaced with the
|
||||
appropriate default.
|
||||
|
||||
Blocks that define their own ``enum`` constraint on the ``model`` field
|
||||
in their inputSchema (e.g. PerplexityBlock) are validated against that
|
||||
enum instead of the generic allowed set.
|
||||
If missing or set to an unsupported value, it is replaced with
|
||||
default_model.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to fix
|
||||
@@ -844,7 +840,7 @@ class AgentFixer:
|
||||
Returns:
|
||||
The fixed agent dictionary
|
||||
"""
|
||||
generic_allowed_models = {"gpt-4o", "claude-opus-4-6"}
|
||||
allowed_models = {"gpt-4o", "claude-opus-4-6"}
|
||||
|
||||
# Create a mapping of block_id to block for quick lookup
|
||||
block_map = {block.get("id"): block for block in blocks}
|
||||
@@ -872,36 +868,20 @@ class AgentFixer:
|
||||
input_default = node.get("input_default", {})
|
||||
current_model = input_default.get("model")
|
||||
|
||||
# Determine allowed models and default from the block's schema.
|
||||
# Blocks with a block-specific enum on the model field (e.g.
|
||||
# PerplexityBlock) use their own enum values; others use the
|
||||
# generic set.
|
||||
model_schema = (
|
||||
block.get("inputSchema", {}).get("properties", {}).get("model", {})
|
||||
)
|
||||
block_model_enum = model_schema.get("enum")
|
||||
|
||||
if block_model_enum:
|
||||
allowed_models = set(block_model_enum)
|
||||
fallback_model = model_schema.get("default", block_model_enum[0])
|
||||
else:
|
||||
allowed_models = generic_allowed_models
|
||||
fallback_model = default_model
|
||||
|
||||
if current_model not in allowed_models:
|
||||
block_name = block.get("name", "Unknown AI Block")
|
||||
if current_model is None:
|
||||
self.add_fix_log(
|
||||
f"Added model parameter '{fallback_model}' to AI "
|
||||
f"Added model parameter '{default_model}' to AI "
|
||||
f"block node {node_id} ({block_name})"
|
||||
)
|
||||
else:
|
||||
self.add_fix_log(
|
||||
f"Replaced unsupported model '{current_model}' "
|
||||
f"with '{fallback_model}' on AI block node "
|
||||
f"with '{default_model}' on AI block node "
|
||||
f"{node_id} ({block_name})"
|
||||
)
|
||||
input_default["model"] = fallback_model
|
||||
input_default["model"] = default_model
|
||||
node["input_default"] = input_default
|
||||
fixed_count += 1
|
||||
|
||||
|
||||
@@ -475,111 +475,6 @@ class TestFixAiModelParameter:
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "claude-opus-4-6"
|
||||
|
||||
def test_block_specific_enum_uses_block_default(self):
|
||||
"""Blocks with their own model enum (e.g. PerplexityBlock) should use
|
||||
the block's allowed models and default, not the generic ones."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=block_id,
|
||||
input_default={"model": "gpt-5.2-2025-12-11"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"name": "PerplexityBlock",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
"default": "perplexity/sonar",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
|
||||
|
||||
def test_block_specific_enum_valid_model_unchanged(self):
|
||||
"""A valid block-specific model should not be replaced."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=block_id,
|
||||
input_default={"model": "perplexity/sonar-pro"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"name": "PerplexityBlock",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
"default": "perplexity/sonar",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar-pro"
|
||||
|
||||
def test_block_specific_enum_missing_model_gets_block_default(self):
|
||||
"""Missing model on a block with enum should use the block's default."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(node_id="n1", block_id=block_id, input_default={})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"name": "PerplexityBlock",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
"default": "perplexity/sonar",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "perplexity/sonar"
|
||||
|
||||
|
||||
class TestFixAgentExecutorBlocks:
|
||||
"""Tests for fix_agent_executor_blocks."""
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Tool for continuing block execution after human review approval."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.copilot.constants import (
|
||||
COPILOT_NODE_PREFIX,
|
||||
COPILOT_SESSION_PREFIX,
|
||||
parse_node_id_from_exec_id,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import review_db
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import execute_block, resolve_block_credentials
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContinueRunBlockTool(BaseTool):
|
||||
"""Tool for continuing a block execution after human review approval."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "continue_run_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Continue executing a block after human review approval. "
|
||||
"Use this after a run_block call returned review_required. "
|
||||
"Pass the review_id from the review_required response. "
|
||||
"The block will execute with the original pre-approved input data."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"review_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The review_id from a previous review_required response. "
|
||||
"This resumes execution with the pre-approved input data."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["review_id"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
review_id = (
|
||||
kwargs.get("review_id", "").strip() if kwargs.get("review_id") else ""
|
||||
)
|
||||
session_id = session.session_id
|
||||
|
||||
if not review_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a review_id", session_id=session_id
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
# Look up and validate the review record via adapter
|
||||
reviews = await review_db().get_reviews_by_node_exec_ids([review_id], user_id)
|
||||
review = reviews.get(review_id)
|
||||
|
||||
if not review:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Review '{review_id}' not found or already executed. "
|
||||
"It may have been consumed by a previous continue_run_block call."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Validate the review belongs to this session
|
||||
expected_graph_exec_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
|
||||
if review.graph_exec_id != expected_graph_exec_id:
|
||||
return ErrorResponse(
|
||||
message="Review does not belong to this session.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if review.status == ReviewStatus.WAITING:
|
||||
return ErrorResponse(
|
||||
message="Review has not been approved yet. "
|
||||
"Please wait for the user to approve the review first.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if review.status == ReviewStatus.REJECTED:
|
||||
return ErrorResponse(
|
||||
message="Review was rejected. The block will not execute.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Extract block_id from review_id: copilot-node-{block_id}:{random_hex}
|
||||
block_id = parse_node_id_from_exec_id(review_id).removeprefix(
|
||||
COPILOT_NODE_PREFIX
|
||||
)
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found", session_id=session_id
|
||||
)
|
||||
|
||||
input_data: dict[str, Any] = (
|
||||
review.payload if isinstance(review.payload, dict) else {}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Continuing block {block.name} ({block_id}) for user {user_id} "
|
||||
f"with review_id={review_id}"
|
||||
)
|
||||
|
||||
matched_creds, missing_creds = await resolve_block_credentials(
|
||||
user_id, block, input_data
|
||||
)
|
||||
if missing_creds:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block.name}' requires credentials that are not configured.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
node_exec_id=review_id,
|
||||
matched_credentials=matched_creds,
|
||||
)
|
||||
|
||||
# Delete review record after successful execution (one-time use)
|
||||
if result.type != "error":
|
||||
await review_db().delete_review_by_node_exec_id(review_id, user_id)
|
||||
|
||||
return result
|
||||
@@ -1,186 +0,0 @@
|
||||
"""Tests for ContinueRunBlockTool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from ._test_data import make_session
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .models import BlockOutputResponse, ErrorResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-continue"
|
||||
|
||||
|
||||
def _make_review_model(
|
||||
node_exec_id: str,
|
||||
status: ReviewStatus = ReviewStatus.APPROVED,
|
||||
payload: dict | None = None,
|
||||
graph_exec_id: str = "",
|
||||
):
|
||||
"""Create a mock PendingHumanReviewModel."""
|
||||
mock = MagicMock()
|
||||
mock.node_exec_id = node_exec_id
|
||||
mock.status = status
|
||||
mock.payload = payload or {"text": "hello"}
|
||||
mock.graph_exec_id = graph_exec_id
|
||||
return mock
|
||||
|
||||
|
||||
class TestContinueRunBlock:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_missing_review_id_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id="",
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "review_id" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_review_not_found_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(return_value={})
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id="copilot-node-some-block:abc12345",
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "not found" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_waiting_review_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
review_id = "copilot-node-some-block:abc12345"
|
||||
graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
review = _make_review_model(
|
||||
review_id, status=ReviewStatus.WAITING, graph_exec_id=graph_exec_id
|
||||
)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
|
||||
return_value={review_id: review}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id=review_id,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "not been approved" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_rejected_review_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
review_id = "copilot-node-some-block:abc12345"
|
||||
graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
review = _make_review_model(
|
||||
review_id, status=ReviewStatus.REJECTED, graph_exec_id=graph_exec_id
|
||||
)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
|
||||
return_value={review_id: review}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id=review_id,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "rejected" in response.message.lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_approved_review_executes_block(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
review_id = "copilot-node-delete-branch-id:abc12345"
|
||||
graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
input_data = {"repo_url": "https://github.com/test/repo", "branch": "main"}
|
||||
review = _make_review_model(
|
||||
review_id,
|
||||
status=ReviewStatus.APPROVED,
|
||||
payload=input_data,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
mock_block = MagicMock()
|
||||
mock_block.name = "Delete Branch"
|
||||
|
||||
async def mock_execute(data, **kwargs):
|
||||
yield "result", "Branch deleted"
|
||||
|
||||
mock_block.execute = mock_execute
|
||||
mock_block.input_schema.get_credentials_fields_info.return_value = []
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="test-workspace-id")
|
||||
)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
|
||||
return_value={review_id: review}
|
||||
)
|
||||
mock_db.delete_review_by_node_exec_id = AsyncMock(return_value=1)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.continue_run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id=review_id,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
assert response.block_name == "Delete Branch"
|
||||
# Verify review was deleted (one-time use)
|
||||
mock_db.delete_review_by_node_exec_id.assert_called_once_with(
|
||||
review_id, _TEST_USER_ID
|
||||
)
|
||||
@@ -1,24 +1,7 @@
|
||||
"""Shared helpers for chat tools."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
from .utils import match_credentials_to_requirements
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
@@ -44,159 +27,3 @@ def get_inputs_from_schema(
|
||||
for name, schema in properties.items()
|
||||
if name not in exclude
|
||||
]
|
||||
|
||||
|
||||
async def execute_block(
|
||||
*,
|
||||
block: AnyBlockSchema,
|
||||
block_id: str,
|
||||
input_data: dict[str, Any],
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
node_exec_id: str,
|
||||
matched_credentials: dict[str, CredentialsMetaInput],
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with full context setup, credential injection, and error handling.
|
||||
|
||||
This is the shared execution path used by both ``run_block`` (after review
|
||||
check) and ``continue_run_block`` (after approval).
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse on success, ErrorResponse on failure.
|
||||
"""
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
|
||||
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
|
||||
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
workspace_id=workspace.id,
|
||||
session_id=session_id,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
)
|
||||
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
"workspace_id": workspace.id,
|
||||
"graph_exec_id": synthetic_graph_id,
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": synthetic_node_id,
|
||||
"graph_version": 1,
|
||||
"graph_id": synthetic_graph_id,
|
||||
}
|
||||
|
||||
# Inject credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_block_credentials(
|
||||
user_id: str,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""Resolve credentials for a block by matching user's available credentials.
|
||||
|
||||
Handles discriminated credentials (e.g. provider selection based on model).
|
||||
|
||||
Returns:
|
||||
(matched_credentials, missing_credentials)
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
requirements = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
if not requirements:
|
||||
return {}, []
|
||||
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
|
||||
|
||||
def _resolve_discriminated_credentials(
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
|
||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(field_info.discriminator)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
return resolved
|
||||
|
||||
@@ -39,7 +39,6 @@ class ResponseType(str, Enum):
|
||||
BLOCK_LIST = "block_list"
|
||||
BLOCK_DETAILS = "block_details"
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
REVIEW_REQUIRED = "review_required"
|
||||
|
||||
# MCP
|
||||
MCP_GUIDE = "mcp_guide"
|
||||
@@ -459,21 +458,6 @@ class BlockOutputResponse(ToolResponseBase):
|
||||
success: bool = True
|
||||
|
||||
|
||||
class ReviewRequiredResponse(ToolResponseBase):
|
||||
"""Response when a block requires human review before execution."""
|
||||
|
||||
type: ResponseType = ResponseType.REVIEW_REQUIRED
|
||||
block_id: str
|
||||
block_name: str
|
||||
review_id: str = Field(description="The review ID for tracking approval status")
|
||||
graph_exec_id: str = Field(
|
||||
description="The graph execution ID for fetching review status"
|
||||
)
|
||||
input_data: dict[str, Any] = Field(
|
||||
description="The input data that requires review"
|
||||
)
|
||||
|
||||
|
||||
class WebFetchResponse(ToolResponseBase):
|
||||
"""Response for web_fetch tool."""
|
||||
|
||||
|
||||
@@ -534,9 +534,7 @@ class RunAgentTool(BaseTool):
|
||||
return ExecutionStartedResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' is awaiting human review. "
|
||||
f"The user can approve or reject inline. After approval, "
|
||||
f"the execution resumes automatically. Use view_agent_output "
|
||||
f"with execution_id='{execution.id}' to check the result."
|
||||
f"Check at {library_agent_link}."
|
||||
),
|
||||
session_id=session_id,
|
||||
execution_id=execution.id,
|
||||
|
||||
@@ -2,34 +2,38 @@
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks import BlockType, get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import (
|
||||
COPILOT_NODE_EXEC_ID_SEPARATOR,
|
||||
COPILOT_NODE_PREFIX,
|
||||
COPILOT_SESSION_PREFIX,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import review_db
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
from .helpers import execute_block, get_inputs_from_schema, resolve_block_credentials
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
BlockDetails,
|
||||
BlockDetailsResponse,
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
ReviewRequiredResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import build_missing_credentials_from_field_info
|
||||
from .utils import (
|
||||
build_missing_credentials_from_field_info,
|
||||
match_credentials_to_requirements,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,9 +52,7 @@ class RunBlockTool(BaseTool):
|
||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||
"do NOT guess or make up block IDs. "
|
||||
"On first attempt (without input_data), returns detailed schema showing "
|
||||
"required inputs and outputs. Then call again with proper input_data to execute. "
|
||||
"If a block requires human review, use continue_run_block with the "
|
||||
"review_id after the user approves."
|
||||
"required inputs and outputs. Then call again with proper input_data to execute."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -164,10 +166,11 @@ class RunBlockTool(BaseTool):
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
(
|
||||
matched_credentials,
|
||||
missing_credentials,
|
||||
) = await resolve_block_credentials(user_id, block, input_data)
|
||||
) = await self._resolve_block_credentials(user_id, block, input_data)
|
||||
|
||||
# Get block schemas for details/validation
|
||||
try:
|
||||
@@ -276,97 +279,169 @@ class RunBlockTool(BaseTool):
|
||||
user_authenticated=True,
|
||||
)
|
||||
|
||||
# Generate synthetic IDs for CoPilot context.
|
||||
# Encode node_id in node_exec_id so it can be extracted later
|
||||
# (e.g. for auto-approve, where we need node_id but have no NodeExecution row).
|
||||
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session.session_id}"
|
||||
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
|
||||
try:
|
||||
# Get or create user's workspace for CoPilot file operations
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
|
||||
# Check for an existing WAITING review for this block with the same input.
|
||||
# If the LLM retries run_block with identical input, we reuse the existing
|
||||
# review instead of creating duplicates. Different inputs = new execution.
|
||||
existing_reviews = await review_db().get_pending_reviews_for_execution(
|
||||
synthetic_graph_id, user_id
|
||||
)
|
||||
existing_review = next(
|
||||
(
|
||||
r
|
||||
for r in existing_reviews
|
||||
if r.node_id == synthetic_node_id
|
||||
and r.status.value == "WAITING"
|
||||
and r.payload == input_data
|
||||
),
|
||||
None,
|
||||
)
|
||||
if existing_review:
|
||||
return ReviewRequiredResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires human review. "
|
||||
f"After the user approves, call continue_run_block with "
|
||||
f"review_id='{existing_review.node_exec_id}' to execute."
|
||||
),
|
||||
session_id=session_id,
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
review_id=existing_review.node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
input_data=input_data,
|
||||
# Generate synthetic IDs for CoPilot context
|
||||
# Each chat session is treated as its own agent with one continuous run
|
||||
# This means:
|
||||
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
||||
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
||||
# - node_exec_id = unique per block execution
|
||||
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_node_id = f"copilot-node-{block_id}"
|
||||
synthetic_node_exec_id = (
|
||||
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
synthetic_node_exec_id = (
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}"
|
||||
f"{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
# Check for HITL review before execution.
|
||||
# This creates the review record in the DB for CoPilot flows.
|
||||
review_context = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
)
|
||||
should_pause, input_data = await block.is_block_exec_need_review(
|
||||
input_data,
|
||||
user_id=user_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
execution_context=review_context,
|
||||
is_graph_execution=False,
|
||||
)
|
||||
if should_pause:
|
||||
return ReviewRequiredResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires human review. "
|
||||
f"After the user approves, call continue_run_block with "
|
||||
f"review_id='{synthetic_node_exec_id}' to execute."
|
||||
),
|
||||
session_id=session_id,
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
review_id=synthetic_node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
input_data=input_data,
|
||||
# Create unified execution context with all required fields
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_exec_id,
|
||||
graph_version=1, # Versions are 1-indexed
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
# Workspace with session scoping
|
||||
workspace_id=workspace.id,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
return await execute_block(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
matched_credentials=matched_credentials,
|
||||
)
|
||||
# Prepare kwargs for block execution
|
||||
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
# Legacy: individual kwargs for blocks not yet using execution_context
|
||||
"workspace_id": workspace.id,
|
||||
"graph_exec_id": synthetic_graph_exec_id,
|
||||
"node_exec_id": synthetic_node_exec_id,
|
||||
"node_id": synthetic_node_id,
|
||||
"graph_version": 1, # Versions are 1-indexed
|
||||
"graph_id": synthetic_graph_id,
|
||||
}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _resolve_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Resolve credentials for a block by matching user's available credentials.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to resolve credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple of (matched_credentials, missing_credentials) - matched credentials
|
||||
are used for block execution, missing ones indicate setup requirements.
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
if not requirements:
|
||||
return {}, []
|
||||
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
|
||||
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
schema = block.input_schema.jsonschema()
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||
|
||||
def _resolve_discriminated_credentials(
|
||||
self,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
|
||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(
|
||||
field_info.discriminator
|
||||
)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
# For host-scoped credentials, add the discriminator value
|
||||
# (e.g., URL) so _credential_is_for_host can match it
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
return resolved
|
||||
|
||||
@@ -12,7 +12,6 @@ from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
ReviewRequiredResponse,
|
||||
)
|
||||
from .run_block import RunBlockTool
|
||||
|
||||
@@ -28,16 +27,9 @@ def make_mock_block(
|
||||
mock.name = name
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.is_sensitive_action = False
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock.input_schema.get_credentials_fields.return_value = {}
|
||||
|
||||
async def _no_review(input_data, **kwargs):
|
||||
return False, input_data
|
||||
|
||||
mock.is_block_exec_need_review = _no_review
|
||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
||||
return mock
|
||||
|
||||
|
||||
@@ -54,7 +46,6 @@ def make_mock_block_with_schema(
|
||||
mock.name = name
|
||||
mock.block_type = BlockType.STANDARD
|
||||
mock.disabled = False
|
||||
mock.is_sensitive_action = False
|
||||
mock.description = f"Test block: {name}"
|
||||
|
||||
input_schema = {
|
||||
@@ -72,12 +63,6 @@ def make_mock_block_with_schema(
|
||||
mock.output_schema = MagicMock()
|
||||
mock.output_schema.jsonschema.return_value = output_schema
|
||||
|
||||
# Default: no review needed, pass through input_data unchanged
|
||||
async def _no_review(input_data, **kwargs):
|
||||
return False, input_data
|
||||
|
||||
mock.is_block_exec_need_review = _no_review
|
||||
|
||||
return mock
|
||||
|
||||
|
||||
@@ -141,15 +126,9 @@ class TestRunBlockFiltering:
|
||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
@@ -175,7 +154,12 @@ class TestRunBlockInputValidation:
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_input_fields_are_rejected(self):
|
||||
"""run_block rejects unknown input fields instead of silently ignoring them."""
|
||||
"""run_block rejects unknown input fields instead of silently ignoring them.
|
||||
|
||||
Scenario: The AI Text Generator block has a field called 'model' (for LLM model
|
||||
selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model'
|
||||
instead. The block should reject the request and return the valid schema.
|
||||
"""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
mock_block = make_mock_block_with_schema(
|
||||
@@ -198,31 +182,27 @@ class TestRunBlockInputValidation:
|
||||
output_properties={"response": {"type": "string"}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
# Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key)
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"prompt": "Write a haiku about coding",
|
||||
"LLM_Model": "claude-opus-4-6",
|
||||
"LLM_Model": "claude-opus-4-6", # WRONG KEY - should be 'model'
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(response, InputValidationErrorResponse)
|
||||
assert "LLM_Model" in response.unrecognized_fields
|
||||
assert "Block was not executed" in response.message
|
||||
assert "inputs" in response.model_dump()
|
||||
assert "inputs" in response.model_dump() # valid schema included
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_multiple_wrong_keys_are_all_reported(self):
|
||||
@@ -241,26 +221,21 @@ class TestRunBlockInputValidation:
|
||||
required_fields=["prompt"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"prompt": "Hello",
|
||||
"llm_model": "claude-opus-4-6",
|
||||
"system_prompt": "Be helpful",
|
||||
"retries": 5,
|
||||
"prompt": "Hello", # correct
|
||||
"llm_model": "claude-opus-4-6", # WRONG - should be 'model'
|
||||
"system_prompt": "Be helpful", # WRONG - should be 'sys_prompt'
|
||||
"retries": 5, # WRONG - should be 'retry'
|
||||
},
|
||||
)
|
||||
|
||||
@@ -287,26 +262,23 @@ class TestRunBlockInputValidation:
|
||||
required_fields=["prompt"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
# 'prompt' is missing AND 'LLM_Model' is an unknown field
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"LLM_Model": "claude-opus-4-6",
|
||||
"LLM_Model": "claude-opus-4-6", # wrong key, and 'prompt' is missing
|
||||
},
|
||||
)
|
||||
|
||||
# Unknown fields are caught first
|
||||
assert isinstance(response, InputValidationErrorResponse)
|
||||
assert "LLM_Model" in response.unrecognized_fields
|
||||
|
||||
@@ -341,11 +313,7 @@ class TestRunBlockInputValidation:
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
"backend.copilot.tools.run_block.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
):
|
||||
@@ -357,7 +325,7 @@ class TestRunBlockInputValidation:
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"prompt": "Write a haiku",
|
||||
"model": "gpt-4o-mini",
|
||||
"model": "gpt-4o-mini", # correct field name
|
||||
},
|
||||
)
|
||||
|
||||
@@ -379,191 +347,20 @@ class TestRunBlockInputValidation:
|
||||
required_fields=["prompt"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
# Only provide valid optional field, missing required 'prompt'
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"model": "gpt-4o-mini",
|
||||
"model": "gpt-4o-mini", # valid but optional
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockDetailsResponse)
|
||||
|
||||
|
||||
class TestRunBlockSensitiveAction:
|
||||
"""Tests for sensitive action HITL review in RunBlockTool.
|
||||
|
||||
run_block calls is_block_exec_need_review() explicitly before execution.
|
||||
When review is needed (should_pause=True), ReviewRequiredResponse is returned.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sensitive_block_paused_returns_review_required(self):
|
||||
"""When is_block_exec_need_review returns should_pause=True, ReviewRequiredResponse is returned."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/test/repo",
|
||||
"branch": "feature-branch",
|
||||
}
|
||||
mock_block = make_mock_block_with_schema(
|
||||
block_id="delete-branch-id",
|
||||
name="Delete Branch",
|
||||
input_properties={
|
||||
"repo_url": {"type": "string"},
|
||||
"branch": {"type": "string"},
|
||||
},
|
||||
required_fields=["repo_url", "branch"],
|
||||
)
|
||||
mock_block.is_sensitive_action = True
|
||||
mock_block.is_block_exec_need_review = AsyncMock(
|
||||
return_value=(True, input_data)
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="delete-branch-id",
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
assert isinstance(response, ReviewRequiredResponse)
|
||||
assert "requires human review" in response.message
|
||||
assert "continue_run_block" in response.message
|
||||
assert response.block_name == "Delete Branch"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sensitive_block_executes_after_approval(self):
|
||||
"""After approval (should_pause=False), sensitive blocks execute and return outputs."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/test/repo",
|
||||
"branch": "feature-branch",
|
||||
}
|
||||
mock_block = make_mock_block_with_schema(
|
||||
block_id="delete-branch-id",
|
||||
name="Delete Branch",
|
||||
input_properties={
|
||||
"repo_url": {"type": "string"},
|
||||
"branch": {"type": "string"},
|
||||
},
|
||||
required_fields=["repo_url", "branch"],
|
||||
)
|
||||
mock_block.is_sensitive_action = True
|
||||
mock_block.is_block_exec_need_review = AsyncMock(
|
||||
return_value=(False, input_data)
|
||||
)
|
||||
|
||||
async def mock_execute(input_data, **kwargs):
|
||||
yield "result", "Branch deleted successfully"
|
||||
|
||||
mock_block.execute = mock_execute
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="test-workspace-id")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="delete-branch-id",
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_sensitive_block_executes_normally(self):
|
||||
"""Non-sensitive blocks skip review and execute directly."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_data = {"url": "https://example.com"}
|
||||
mock_block = make_mock_block_with_schema(
|
||||
block_id="http-request-id",
|
||||
name="HTTP Request",
|
||||
input_properties={
|
||||
"url": {"type": "string"},
|
||||
},
|
||||
required_fields=["url"],
|
||||
)
|
||||
mock_block.is_sensitive_action = False
|
||||
mock_block.is_block_exec_need_review = AsyncMock(
|
||||
return_value=(False, input_data)
|
||||
)
|
||||
|
||||
async def mock_execute(input_data, **kwargs):
|
||||
yield "response", {"status": 200}
|
||||
|
||||
mock_block.execute = mock_execute
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="test-workspace-id")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="http-request-id",
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
|
||||
@@ -34,11 +34,6 @@ logger = logging.getLogger(__name__)
|
||||
_AUTH_STATUS_CODES = {401, 403}
|
||||
|
||||
|
||||
def _service_name(host: str) -> str:
|
||||
"""Strip the 'mcp.' prefix from an MCP hostname: 'mcp.sentry.dev' → 'sentry.dev'"""
|
||||
return host[4:] if host.startswith("mcp.") else host
|
||||
|
||||
|
||||
class RunMCPToolTool(BaseTool):
|
||||
"""
|
||||
Tool for discovering and executing tools on any MCP server.
|
||||
@@ -308,8 +303,8 @@ class RunMCPToolTool(BaseTool):
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unable to connect to {_service_name(server_host(server_url))} "
|
||||
"— no credentials configured."
|
||||
f"The MCP server at {server_host(server_url)} requires authentication, "
|
||||
"but no credential configuration was found."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -317,13 +312,15 @@ class RunMCPToolTool(BaseTool):
|
||||
missing_creds_list = list(missing_creds_dict.values())
|
||||
|
||||
host = server_host(server_url)
|
||||
service = _service_name(host)
|
||||
return SetupRequirementsResponse(
|
||||
message=(f"To continue, sign in to {service} and approve access."),
|
||||
message=(
|
||||
f"The MCP server at {host} requires authentication. "
|
||||
"Please connect your credentials to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=server_url,
|
||||
agent_name=service,
|
||||
agent_name=f"MCP: {host}",
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
|
||||
@@ -65,8 +65,9 @@ async def test_run_block_returns_details_when_no_input_provided():
|
||||
return_value=http_block,
|
||||
):
|
||||
# Mock credentials check to return no missing credentials
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.resolve_block_credentials",
|
||||
with patch.object(
|
||||
RunBlockTool,
|
||||
"_resolve_block_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=({}, []), # (matched_credentials, missing_credentials)
|
||||
):
|
||||
@@ -122,8 +123,9 @@ async def test_run_block_returns_details_when_only_credentials_provided():
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.resolve_block_credentials",
|
||||
with patch.object(
|
||||
RunBlockTool,
|
||||
"_resolve_block_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(
|
||||
{
|
||||
|
||||
@@ -756,4 +756,4 @@ async def test_build_setup_requirements_returns_setup_response():
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_id == _SERVER_URL
|
||||
assert "sign in" in result.message.lower()
|
||||
assert "authentication" in result.message.lower()
|
||||
|
||||
@@ -100,31 +100,19 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
||||
LlmModel.GEMINI_2_5_PRO_PREVIEW: 4,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.GEMINI_3_1_PRO_PREVIEW: 5,
|
||||
LlmModel.GEMINI_3_FLASH_PREVIEW: 2,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
|
||||
LlmModel.GEMINI_2_5_FLASH: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH: 1,
|
||||
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.MISTRAL_LARGE_3: 2,
|
||||
LlmModel.MISTRAL_MEDIUM_3_1: 2,
|
||||
LlmModel.MISTRAL_SMALL_3_2: 1,
|
||||
LlmModel.CODESTRAL: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
|
||||
LlmModel.COHERE_COMMAND_A_03_2025: 3,
|
||||
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: 3,
|
||||
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: 6,
|
||||
LlmModel.COHERE_COMMAND_A_VISION_07_2025: 3,
|
||||
LlmModel.DEEPSEEK_CHAT: 2,
|
||||
LlmModel.DEEPSEEK_R1_0528: 1,
|
||||
LlmModel.PERPLEXITY_SONAR: 1,
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
||||
@@ -132,7 +120,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: 1,
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
|
||||
LlmModel.MICROSOFT_PHI_4: 1,
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
|
||||
LlmModel.META_LLAMA_4_SCOUT: 1,
|
||||
LlmModel.META_LLAMA_4_MAVERICK: 1,
|
||||
@@ -140,7 +127,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.GROK_3: 3,
|
||||
LlmModel.GROK_4: 9,
|
||||
LlmModel.GROK_4_FAST: 1,
|
||||
LlmModel.GROK_4_1_FAST: 1,
|
||||
|
||||
@@ -116,16 +116,3 @@ def workspace_db():
|
||||
workspace_db = get_database_manager_async_client()
|
||||
|
||||
return workspace_db
|
||||
|
||||
|
||||
def review_db():
|
||||
if db.is_connected():
|
||||
from backend.data import human_review as _review_db
|
||||
|
||||
review_db = _review_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
review_db = get_database_manager_async_client()
|
||||
|
||||
return review_db
|
||||
|
||||
@@ -79,10 +79,7 @@ from backend.data.graph import (
|
||||
from backend.data.human_review import (
|
||||
cancel_pending_reviews_for_execution,
|
||||
check_approval,
|
||||
delete_review_by_node_exec_id,
|
||||
get_or_create_human_review,
|
||||
get_pending_reviews_for_execution,
|
||||
get_reviews_by_node_exec_ids,
|
||||
has_pending_reviews_for_graph_exec,
|
||||
update_review_processed_status,
|
||||
)
|
||||
@@ -249,10 +246,7 @@ class DatabaseManager(AppService):
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||
check_approval = _(check_approval)
|
||||
delete_review_by_node_exec_id = _(delete_review_by_node_exec_id)
|
||||
get_or_create_human_review = _(get_or_create_human_review)
|
||||
get_pending_reviews_for_execution = _(get_pending_reviews_for_execution)
|
||||
get_reviews_by_node_exec_ids = _(get_reviews_by_node_exec_ids)
|
||||
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
||||
update_review_processed_status = _(update_review_processed_status)
|
||||
|
||||
@@ -439,10 +433,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||
check_approval = d.check_approval
|
||||
delete_review_by_node_exec_id = d.delete_review_by_node_exec_id
|
||||
get_or_create_human_review = d.get_or_create_human_review
|
||||
get_pending_reviews_for_execution = d.get_pending_reviews_for_execution
|
||||
get_reviews_by_node_exec_ids = d.get_reviews_by_node_exec_ids
|
||||
update_review_processed_status = d.update_review_processed_status
|
||||
|
||||
# ============ User Comms ============ #
|
||||
|
||||
@@ -17,10 +17,6 @@ from backend.api.features.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
SafeJsonData,
|
||||
)
|
||||
from backend.copilot.constants import (
|
||||
is_copilot_synthetic_id,
|
||||
parse_node_id_from_exec_id,
|
||||
)
|
||||
from backend.data.execution import get_graph_execution_meta
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
@@ -127,13 +123,11 @@ async def create_auto_approval_record(
|
||||
Raises:
|
||||
ValueError: If the graph execution doesn't belong to the user
|
||||
"""
|
||||
# Validate ownership: if a graph execution record exists, it must belong
|
||||
# to this user. Non-graph executions (e.g. CoPilot) won't have a record.
|
||||
if not is_copilot_synthetic_id(
|
||||
graph_exec_id
|
||||
) and not await get_graph_execution_meta(
|
||||
# Validate that the graph execution belongs to this user (defense in depth)
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
):
|
||||
)
|
||||
if not graph_exec:
|
||||
raise ValueError(
|
||||
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
|
||||
)
|
||||
@@ -271,7 +265,7 @@ async def get_pending_review_by_node_exec_id(
|
||||
|
||||
async def get_reviews_by_node_exec_ids(
|
||||
node_exec_ids: list[str], user_id: str
|
||||
) -> dict[str, PendingHumanReviewModel]:
|
||||
) -> dict[str, "PendingHumanReviewModel"]:
|
||||
"""
|
||||
Get multiple reviews by their node execution IDs regardless of status.
|
||||
|
||||
@@ -298,26 +292,21 @@ async def get_reviews_by_node_exec_ids(
|
||||
if not reviews:
|
||||
return {}
|
||||
|
||||
# Split into synthetic (CoPilot) and real IDs for different resolution paths
|
||||
synthetic_ids = {
|
||||
r.nodeExecId for r in reviews if is_copilot_synthetic_id(r.nodeExecId)
|
||||
}
|
||||
real_ids = [r.nodeExecId for r in reviews if r.nodeExecId not in synthetic_ids]
|
||||
# Batch fetch all node executions to avoid N+1 queries
|
||||
node_exec_ids_to_fetch = [review.nodeExecId for review in reviews]
|
||||
node_execs = await AgentNodeExecution.prisma().find_many(
|
||||
where={"id": {"in": node_exec_ids_to_fetch}},
|
||||
include={"Node": True},
|
||||
)
|
||||
|
||||
# Batch fetch real node executions to avoid N+1 queries
|
||||
node_exec_id_to_node_id: dict[str, str] = {}
|
||||
if real_ids:
|
||||
node_execs = await AgentNodeExecution.prisma().find_many(
|
||||
where={"id": {"in": real_ids}},
|
||||
)
|
||||
node_exec_id_to_node_id = {ne.id: ne.agentNodeId for ne in node_execs}
|
||||
# Create mapping from node_exec_id to node_id
|
||||
node_exec_id_to_node_id = {
|
||||
node_exec.id: node_exec.agentNodeId for node_exec in node_execs
|
||||
}
|
||||
|
||||
result = {}
|
||||
for review in reviews:
|
||||
if review.nodeExecId in synthetic_ids:
|
||||
node_id = parse_node_id_from_exec_id(review.nodeExecId)
|
||||
else:
|
||||
node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId)
|
||||
node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId)
|
||||
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
|
||||
review, node_id=node_id
|
||||
)
|
||||
@@ -342,19 +331,6 @@ async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
|
||||
return count > 0
|
||||
|
||||
|
||||
async def _resolve_node_id(node_exec_id: str, get_node_execution) -> str:
|
||||
"""Resolve node_id from a node_exec_id.
|
||||
|
||||
For CoPilot synthetic IDs (e.g. copilot-node-block-id:abc12345),
|
||||
extract the node_id portion (copilot-node-block-id).
|
||||
For real graph executions, look up the NodeExecution record.
|
||||
"""
|
||||
if is_copilot_synthetic_id(node_exec_id):
|
||||
return parse_node_id_from_exec_id(node_exec_id)
|
||||
node_exec = await get_node_execution(node_exec_id)
|
||||
return node_exec.node_id if node_exec else node_exec_id
|
||||
|
||||
|
||||
async def get_pending_reviews_for_user(
|
||||
user_id: str, page: int = 1, page_size: int = 25
|
||||
) -> list["PendingHumanReviewModel"]:
|
||||
@@ -385,7 +361,8 @@ async def get_pending_reviews_for_user(
|
||||
# Fetch node_id for each review from NodeExecution
|
||||
result = []
|
||||
for review in reviews:
|
||||
node_id = await _resolve_node_id(review.nodeExecId, get_node_execution)
|
||||
node_exec = await get_node_execution(review.nodeExecId)
|
||||
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
|
||||
|
||||
return result
|
||||
@@ -393,7 +370,7 @@ async def get_pending_reviews_for_user(
|
||||
|
||||
async def get_pending_reviews_for_execution(
|
||||
graph_exec_id: str, user_id: str
|
||||
) -> list[PendingHumanReviewModel]:
|
||||
) -> list["PendingHumanReviewModel"]:
|
||||
"""
|
||||
Get all pending reviews for a specific graph execution.
|
||||
|
||||
@@ -419,7 +396,8 @@ async def get_pending_reviews_for_execution(
|
||||
# Fetch node_id for each review from NodeExecution
|
||||
result = []
|
||||
for review in reviews:
|
||||
node_id = await _resolve_node_id(review.nodeExecId, get_node_execution)
|
||||
node_exec = await get_node_execution(review.nodeExecId)
|
||||
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
|
||||
|
||||
return result
|
||||
@@ -531,12 +509,8 @@ async def process_all_reviews_for_execution(
|
||||
|
||||
result = {}
|
||||
for review in all_result_reviews:
|
||||
if is_copilot_synthetic_id(review.nodeExecId):
|
||||
# CoPilot synthetic node_exec_ids encode node_id as "{node_id}:{random}"
|
||||
node_id = parse_node_id_from_exec_id(review.nodeExecId)
|
||||
else:
|
||||
node_exec = await get_node_execution(review.nodeExecId)
|
||||
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||
node_exec = await get_node_execution(review.nodeExecId)
|
||||
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
|
||||
review, node_id=node_id
|
||||
)
|
||||
@@ -590,21 +564,3 @@ async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str)
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def delete_review_by_node_exec_id(node_exec_id: str, user_id: str) -> int:
|
||||
"""Delete a review record by node execution ID after it has been consumed.
|
||||
|
||||
Used by CoPilot's continue_run_block to clean up one-time-use review records
|
||||
after successful execution.
|
||||
|
||||
Args:
|
||||
node_exec_id: The node execution ID of the review to delete
|
||||
user_id: User ID for authorization
|
||||
|
||||
Returns:
|
||||
Number of records deleted
|
||||
"""
|
||||
return await PendingHumanReview.prisma().delete_many(
|
||||
where={"nodeExecId": node_exec_id, "userId": user_id}
|
||||
)
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""LLM Registry - Dynamic model management system."""
|
||||
|
||||
from .model import ModelMetadata
|
||||
from .registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
get_enabled_models,
|
||||
get_model,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Models
|
||||
"ModelMetadata",
|
||||
"RegistryModel",
|
||||
"RegistryModelCost",
|
||||
"RegistryModelCreator",
|
||||
# Functions
|
||||
"refresh_llm_registry",
|
||||
"get_model",
|
||||
"get_all_models",
|
||||
"get_enabled_models",
|
||||
"get_schema_options",
|
||||
"get_default_model_slug",
|
||||
"get_all_model_slugs_for_validation",
|
||||
]
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Type definitions for LLM model metadata.
|
||||
|
||||
Re-exports ModelMetadata from blocks.llm to avoid type collision.
|
||||
In PR #5 (block integration), this will become the canonical location.
|
||||
"""
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
|
||||
__all__ = ["ModelMetadata"]
|
||||
240
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
240
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Core LLM registry implementation for managing models dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCost:
|
||||
"""Cost configuration for an LLM model."""
|
||||
|
||||
unit: str # "RUN" or "TOKENS"
|
||||
credit_cost: int
|
||||
credential_provider: str
|
||||
credential_id: str | None
|
||||
credential_type: str | None
|
||||
currency: str | None
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCreator:
|
||||
"""Creator information for an LLM model."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
website_url: str | None
|
||||
logo_url: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModel:
|
||||
"""Represents a model in the LLM registry."""
|
||||
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None
|
||||
metadata: ModelMetadata
|
||||
capabilities: dict[str, Any]
|
||||
extra_metadata: dict[str, Any]
|
||||
provider_display_name: str
|
||||
is_enabled: bool
|
||||
is_recommended: bool = False
|
||||
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
|
||||
creator: RegistryModelCreator | None = None
|
||||
|
||||
|
||||
# In-memory cache (will be replaced with Redis in PR #6)
|
||||
_dynamic_models: dict[str, RegistryModel] = {}
|
||||
_schema_options: list[dict[str, str]] = []
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""
|
||||
Refresh the LLM registry from the database.
|
||||
|
||||
Fetches all models with their costs, providers, and creators,
|
||||
then updates the in-memory cache.
|
||||
"""
|
||||
async with _lock:
|
||||
try:
|
||||
records = await prisma.models.LlmModel.prisma().find_many(
|
||||
include={
|
||||
"Provider": True,
|
||||
"Costs": True,
|
||||
"Creator": True,
|
||||
}
|
||||
)
|
||||
logger.info(f"Fetched {len(records)} LLM models from database")
|
||||
|
||||
# Build model instances
|
||||
new_models: dict[str, RegistryModel] = {}
|
||||
for record in records:
|
||||
# Parse costs
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
unit=str(cost.unit), # Convert enum to string
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=dict(cost.metadata or {}),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
|
||||
# Parse creator
|
||||
creator = None
|
||||
if record.Creator:
|
||||
creator = RegistryModelCreator(
|
||||
id=record.Creator.id,
|
||||
name=record.Creator.name,
|
||||
display_name=record.Creator.displayName,
|
||||
description=record.Creator.description,
|
||||
website_url=record.Creator.websiteUrl,
|
||||
logo_url=record.Creator.logoUrl,
|
||||
)
|
||||
|
||||
# Parse capabilities
|
||||
capabilities = dict(record.capabilities or {})
|
||||
|
||||
# Build metadata from record
|
||||
# Warn if Provider relation is missing (indicates data corruption)
|
||||
if not record.Provider:
|
||||
logger.warning(
|
||||
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
|
||||
f"falling back to providerId {record.providerId}"
|
||||
)
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
provider_display = (
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
)
|
||||
|
||||
# Extract creator name (fallback to "Unknown" if no creator)
|
||||
creator_name = (
|
||||
record.Creator.displayName if record.Creator else "Unknown"
|
||||
)
|
||||
|
||||
# Price tier defaults to 1 if not set
|
||||
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
|
||||
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=(
|
||||
record.maxOutputTokens
|
||||
if record.maxOutputTokens is not None
|
||||
else record.contextWindow
|
||||
),
|
||||
display_name=record.displayName,
|
||||
provider_name=provider_display,
|
||||
creator_name=creator_name,
|
||||
price_tier=price_tier,
|
||||
)
|
||||
|
||||
# Create model instance
|
||||
model = RegistryModel(
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=capabilities,
|
||||
extra_metadata=dict(record.metadata or {}),
|
||||
provider_display_name=provider_display,
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
costs=costs,
|
||||
creator=creator,
|
||||
)
|
||||
new_models[record.slug] = model
|
||||
|
||||
# Atomic swap
|
||||
global _dynamic_models, _schema_options
|
||||
_dynamic_models = new_models
|
||||
_schema_options = _build_schema_options()
|
||||
|
||||
logger.info(
|
||||
f"LLM registry refreshed: {len(_dynamic_models)} models, "
|
||||
f"{len(_schema_options)} schema options"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
|
||||
"""Build schema options for model selection dropdown. Only includes enabled models."""
|
||||
return [
|
||||
{
|
||||
"label": model.display_name,
|
||||
"value": model.slug,
|
||||
"group": model.metadata.provider,
|
||||
"description": model.description or "",
|
||||
}
|
||||
for model in sorted(
|
||||
_dynamic_models.values(), key=lambda m: m.display_name.lower()
|
||||
)
|
||||
if model.is_enabled
|
||||
]
|
||||
|
||||
|
||||
def get_model(slug: str) -> RegistryModel | None:
|
||||
"""Get a model by slug from the registry."""
|
||||
return _dynamic_models.get(slug)
|
||||
|
||||
|
||||
def get_all_models() -> list[RegistryModel]:
|
||||
"""Get all models from the registry (including disabled)."""
|
||||
return list(_dynamic_models.values())
|
||||
|
||||
|
||||
def get_enabled_models() -> list[RegistryModel]:
|
||||
"""Get only enabled models from the registry."""
|
||||
return [model for model in _dynamic_models.values() if model.is_enabled]
|
||||
|
||||
|
||||
def get_schema_options() -> list[dict[str, str]]:
|
||||
"""Get schema options for model selection dropdown (enabled models only)."""
|
||||
return _schema_options
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
|
||||
"""Get the default model slug (first recommended, or first enabled)."""
|
||||
# Sort once and use next() to short-circuit on first match
|
||||
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
|
||||
|
||||
# Prefer recommended models
|
||||
recommended = next(
|
||||
(m.slug for m in models if m.is_recommended and m.is_enabled), None
|
||||
)
|
||||
if recommended:
|
||||
return recommended
|
||||
|
||||
# Fallback to first enabled model
|
||||
return next((m.slug for m in models if m.is_enabled), None)
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> list[str]:
|
||||
"""
|
||||
Get all model slugs for validation (enables migrate_llm_models to work).
|
||||
Returns slugs for enabled models only.
|
||||
"""
|
||||
return [model.slug for model in _dynamic_models.values() if model.is_enabled]
|
||||
@@ -8,8 +8,6 @@ from backend.api.model import NotificationPayload
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.util.settings import Settings
|
||||
|
||||
_settings = Settings()
|
||||
|
||||
|
||||
class NotificationEvent(BaseModel):
|
||||
"""Generic notification event destined for websocket delivery."""
|
||||
@@ -28,7 +26,7 @@ class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return _settings.config.notification_event_bus_name
|
||||
return Settings().config.notification_event_bus_name
|
||||
|
||||
async def publish(self, event: NotificationEvent) -> None:
|
||||
await self.publish_event(event, event.user_id)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""LLM registry public API."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
67
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
67
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Pydantic models for LLM registry public API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class LlmModelCost(pydantic.BaseModel):
|
||||
"""Cost configuration for an LLM model."""
|
||||
|
||||
unit: str # "RUN" or "TOKENS"
|
||||
credit_cost: int = pydantic.Field(ge=0)
|
||||
credential_provider: str
|
||||
credential_id: str | None = None
|
||||
credential_type: str | None = None
|
||||
currency: str | None = None
|
||||
metadata: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class LlmModelCreator(pydantic.BaseModel):
|
||||
"""Represents the organization that created/trained the model."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
website_url: str | None = None
|
||||
logo_url: str | None = None
|
||||
|
||||
|
||||
class LlmModel(pydantic.BaseModel):
|
||||
"""Public-facing LLM model information."""
|
||||
|
||||
slug: str
|
||||
display_name: str
|
||||
description: str | None = None
|
||||
provider_name: str
|
||||
creator: LlmModelCreator | None = None
|
||||
context_window: int
|
||||
max_output_tokens: int | None = None
|
||||
price_tier: int # 1=cheapest, 2=medium, 3=expensive
|
||||
is_recommended: bool = False
|
||||
capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
|
||||
costs: list[LlmModelCost] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmProvider(pydantic.BaseModel):
|
||||
"""Provider with its enabled models."""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
models: list[LlmModel] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class LlmModelsResponse(pydantic.BaseModel):
|
||||
"""Response for GET /llm/models."""
|
||||
|
||||
models: list[LlmModel]
|
||||
total: int
|
||||
|
||||
|
||||
class LlmProvidersResponse(pydantic.BaseModel):
|
||||
"""Response for GET /llm/providers."""
|
||||
|
||||
providers: list[LlmProvider]
|
||||
141
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
141
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Public read-only API for LLM registry."""
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
|
||||
from backend.data.llm_registry import (
|
||||
RegistryModelCreator,
|
||||
get_all_models,
|
||||
get_enabled_models,
|
||||
)
|
||||
from backend.server.v2.llm import model as llm_model
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
prefix="/llm",
|
||||
tags=["llm"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
def _map_creator(
|
||||
creator: RegistryModelCreator | None,
|
||||
) -> llm_model.LlmModelCreator | None:
|
||||
"""Convert registry creator to API model."""
|
||||
if not creator:
|
||||
return None
|
||||
return llm_model.LlmModelCreator(
|
||||
id=creator.id,
|
||||
name=creator.name,
|
||||
display_name=creator.display_name,
|
||||
description=creator.description,
|
||||
website_url=creator.website_url,
|
||||
logo_url=creator.logo_url,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models", response_model=llm_model.LlmModelsResponse)
|
||||
async def list_models(
|
||||
enabled_only: bool = fastapi.Query(
|
||||
default=True, description="Only return enabled models"
|
||||
),
|
||||
):
|
||||
"""
|
||||
List all LLM models available to users.
|
||||
|
||||
Returns models from the in-memory registry cache.
|
||||
Use enabled_only=true to filter to only enabled models (default).
|
||||
"""
|
||||
# Get models from in-memory registry
|
||||
registry_models = get_enabled_models() if enabled_only else get_all_models()
|
||||
|
||||
# Map to API response models
|
||||
models = [
|
||||
llm_model.LlmModel(
|
||||
slug=model.slug,
|
||||
display_name=model.display_name,
|
||||
description=model.description,
|
||||
provider_name=model.provider_display_name,
|
||||
creator=_map_creator(model.creator),
|
||||
context_window=model.metadata.context_window,
|
||||
max_output_tokens=model.metadata.max_output_tokens,
|
||||
price_tier=model.metadata.price_tier,
|
||||
is_recommended=model.is_recommended,
|
||||
capabilities=model.capabilities,
|
||||
costs=[
|
||||
llm_model.LlmModelCost(
|
||||
unit=cost.unit,
|
||||
credit_cost=cost.credit_cost,
|
||||
credential_provider=cost.credential_provider,
|
||||
credential_id=cost.credential_id,
|
||||
credential_type=cost.credential_type,
|
||||
currency=cost.currency,
|
||||
metadata=cost.metadata,
|
||||
)
|
||||
for cost in model.costs
|
||||
],
|
||||
)
|
||||
for model in registry_models
|
||||
]
|
||||
|
||||
return llm_model.LlmModelsResponse(models=models, total=len(models))
|
||||
|
||||
|
||||
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
|
||||
async def list_providers():
|
||||
"""
|
||||
List all LLM providers with their enabled models.
|
||||
|
||||
Groups enabled models by provider from the in-memory registry.
|
||||
"""
|
||||
# Get all enabled models and group by provider
|
||||
registry_models = get_enabled_models()
|
||||
|
||||
# Group models by provider
|
||||
provider_map: dict[str, list] = {}
|
||||
for model in registry_models:
|
||||
provider_key = model.metadata.provider
|
||||
if provider_key not in provider_map:
|
||||
provider_map[provider_key] = []
|
||||
provider_map[provider_key].append(model)
|
||||
|
||||
# Build provider responses
|
||||
providers = []
|
||||
for provider_key, models in sorted(provider_map.items()):
|
||||
# Use the first model's provider display name
|
||||
display_name = models[0].provider_display_name if models else provider_key
|
||||
|
||||
providers.append(
|
||||
llm_model.LlmProvider(
|
||||
name=provider_key,
|
||||
display_name=display_name,
|
||||
models=[
|
||||
llm_model.LlmModel(
|
||||
slug=model.slug,
|
||||
display_name=model.display_name,
|
||||
description=model.description,
|
||||
provider_name=model.provider_display_name,
|
||||
creator=_map_creator(model.creator),
|
||||
context_window=model.metadata.context_window,
|
||||
max_output_tokens=model.metadata.max_output_tokens,
|
||||
price_tier=model.metadata.price_tier,
|
||||
is_recommended=model.is_recommended,
|
||||
capabilities=model.capabilities,
|
||||
costs=[
|
||||
llm_model.LlmModelCost(
|
||||
unit=cost.unit,
|
||||
credit_cost=cost.credit_cost,
|
||||
credential_provider=cost.credential_provider,
|
||||
credential_id=cost.credential_id,
|
||||
credential_type=cost.credential_type,
|
||||
currency=cost.currency,
|
||||
metadata=cost.metadata,
|
||||
)
|
||||
for cost in model.costs
|
||||
],
|
||||
)
|
||||
for model in sorted(models, key=lambda m: m.display_name)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return llm_model.LlmProvidersResponse(providers=providers)
|
||||
@@ -1,22 +0,0 @@
|
||||
-- Migrate Gemini 3 Pro Preview to Gemini 3.1 Pro Preview
|
||||
-- This updates all AgentNode blocks that use the deprecated Gemini 3 Pro Preview model
|
||||
-- Google is shutting down google/gemini-3-pro-preview on March 9, 2026
|
||||
|
||||
-- Update AgentNode constant inputs
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"google/gemini-3.1-pro-preview"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'google/gemini-3-pro-preview';
|
||||
|
||||
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
|
||||
UPDATE "AgentNodeExecutionInputOutput"
|
||||
SET "data" = JSONB_SET(
|
||||
"data"::jsonb,
|
||||
'{model}',
|
||||
'"google/gemini-3.1-pro-preview"'::jsonb
|
||||
)
|
||||
WHERE "agentPresetId" IS NOT NULL
|
||||
AND "data"::jsonb->>'model' = 'google/gemini-3-pro-preview';
|
||||
@@ -1,7 +0,0 @@
|
||||
-- Remove GraphExecution foreign key from PendingHumanReview
|
||||
-- The graphExecId column remains for querying, but we remove the FK constraint
|
||||
-- to AgentGraphExecution since PendingHumanReview records can now be created
|
||||
-- with synthetic graph_exec_ids (e.g., CoPilot direct block execution uses
|
||||
-- "copilot-session-{session_id}" as graph_exec_id).
|
||||
|
||||
ALTER TABLE "PendingHumanReview" DROP CONSTRAINT IF EXISTS "PendingHumanReview_graphExecId_fkey";
|
||||
@@ -0,0 +1,148 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmProvider" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"defaultCredentialProvider" TEXT,
|
||||
"defaultCredentialId" TEXT,
|
||||
"defaultCredentialType" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelCreator" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"websiteUrl" TEXT,
|
||||
"logoUrl" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModel" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"slug" TEXT NOT NULL,
|
||||
"displayName" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"providerId" TEXT NOT NULL,
|
||||
"creatorId" TEXT,
|
||||
"contextWindow" INTEGER NOT NULL,
|
||||
"maxOutputTokens" INTEGER,
|
||||
"priceTier" INTEGER NOT NULL DEFAULT 1,
|
||||
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
|
||||
"isRecommended" BOOLEAN NOT NULL DEFAULT false,
|
||||
"supportsTools" BOOLEAN NOT NULL DEFAULT false,
|
||||
"supportsJsonOutput" BOOLEAN NOT NULL DEFAULT false,
|
||||
"supportsReasoning" BOOLEAN NOT NULL DEFAULT false,
|
||||
"supportsParallelToolCalls" BOOLEAN NOT NULL DEFAULT false,
|
||||
"capabilities" JSONB NOT NULL DEFAULT '{}',
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelCost" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
|
||||
"creditCost" INTEGER NOT NULL,
|
||||
"credentialProvider" TEXT NOT NULL,
|
||||
"credentialId" TEXT,
|
||||
"credentialType" TEXT,
|
||||
"currency" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
"llmModelId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LlmModelMigration" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"sourceModelSlug" TEXT NOT NULL,
|
||||
"targetModelSlug" TEXT NOT NULL,
|
||||
"reason" TEXT,
|
||||
"migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
|
||||
"nodeCount" INTEGER NOT NULL,
|
||||
"customCreditCost" INTEGER,
|
||||
"isReverted" BOOLEAN NOT NULL DEFAULT false,
|
||||
"revertedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "LlmProvider_name_key" ON "LlmProvider"("name");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "LlmModel_slug_key" ON "LlmModel"("slug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");
|
||||
|
||||
-- CreateIndex (partial unique for default costs - no specific credential)
|
||||
CREATE UNIQUE INDEX "LlmModelCost_default_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL;
|
||||
|
||||
-- CreateIndex (partial unique for credential-specific costs)
|
||||
CREATE UNIQUE INDEX "LlmModelCost_credential_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "credentialId", "unit") WHERE "credentialId" IS NOT NULL;
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");
|
||||
|
||||
-- CreateIndex (partial unique to prevent multiple active migrations per source)
|
||||
CREATE UNIQUE INDEX "LlmModelMigration_active_source_key" ON "LlmModelMigration"("sourceModelSlug") WHERE "isReverted" = false;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey" FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_sourceModelSlug_fkey" FOREIGN KEY ("sourceModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_targetModelSlug_fkey" FOREIGN KEY ("targetModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddCheckConstraints (enforce data integrity)
|
||||
ALTER TABLE "LlmModel"
|
||||
ADD CONSTRAINT "LlmModel_priceTier_check" CHECK ("priceTier" BETWEEN 1 AND 3);
|
||||
|
||||
ALTER TABLE "LlmModelCost"
|
||||
ADD CONSTRAINT "LlmModelCost_creditCost_check" CHECK ("creditCost" >= 0);
|
||||
|
||||
ALTER TABLE "LlmModelMigration"
|
||||
ADD CONSTRAINT "LlmModelMigration_nodeCount_check" CHECK ("nodeCount" >= 0),
|
||||
ADD CONSTRAINT "LlmModelMigration_customCreditCost_check" CHECK ("customCreditCost" IS NULL OR "customCreditCost" >= 0);
|
||||
@@ -1,40 +0,0 @@
|
||||
-- Fix PerplexityBlock nodes that have invalid model values (e.g. gpt-4o,
|
||||
-- gpt-5.2-2025-12-11) set by the agent generator. Defaults them to the
|
||||
-- standard "perplexity/sonar" model.
|
||||
--
|
||||
-- PerplexityBlock ID: c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f
|
||||
-- Valid models: perplexity/sonar, perplexity/sonar-pro, perplexity/sonar-deep-research
|
||||
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"perplexity/sonar"'::jsonb
|
||||
)
|
||||
WHERE "agentBlockId" = 'c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f'
|
||||
AND "constantInput"::jsonb ? 'model'
|
||||
AND "constantInput"::jsonb->>'model' NOT IN (
|
||||
'perplexity/sonar',
|
||||
'perplexity/sonar-pro',
|
||||
'perplexity/sonar-deep-research'
|
||||
);
|
||||
|
||||
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput).
|
||||
-- The table links to AgentNode through AgentNodeExecution, not directly.
|
||||
UPDATE "AgentNodeExecutionInputOutput" io
|
||||
SET "data" = JSONB_SET(
|
||||
io."data"::jsonb,
|
||||
'{model}',
|
||||
'"perplexity/sonar"'::jsonb
|
||||
)
|
||||
FROM "AgentNodeExecution" exe
|
||||
JOIN "AgentNode" n ON n."id" = exe."agentNodeId"
|
||||
WHERE io."agentPresetId" IS NOT NULL
|
||||
AND (io."referencedByInputExecId" = exe."id" OR io."referencedByOutputExecId" = exe."id")
|
||||
AND n."agentBlockId" = 'c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f'
|
||||
AND io."data"::jsonb ? 'model'
|
||||
AND io."data"::jsonb->>'model' NOT IN (
|
||||
'perplexity/sonar',
|
||||
'perplexity/sonar-pro',
|
||||
'perplexity/sonar-deep-research'
|
||||
);
|
||||
@@ -566,6 +566,8 @@ model AgentGraphExecution {
|
||||
shareToken String? @unique
|
||||
sharedAt DateTime?
|
||||
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
@@index([userId, isDeleted, createdAt])
|
||||
@@index([createdAt])
|
||||
@@ -662,7 +664,8 @@ model PendingHumanReview {
|
||||
updatedAt DateTime? @updatedAt
|
||||
reviewedAt DateTime?
|
||||
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([nodeExecId]) // One pending review per node execution
|
||||
@@index([userId, status])
|
||||
@@ -1301,3 +1304,164 @@ model OAuthRefreshToken {
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LLM Registry Models
|
||||
// ============================================================================
|
||||
|
||||
enum LlmCostUnit {
|
||||
RUN
|
||||
TOKENS
|
||||
}
|
||||
|
||||
model LlmProvider {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
defaultCredentialProvider String?
|
||||
defaultCredentialId String?
|
||||
defaultCredentialType String?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
|
||||
}
|
||||
|
||||
model LlmModel {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
slug String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
providerId String
|
||||
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
|
||||
|
||||
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
|
||||
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
|
||||
creatorId String?
|
||||
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
|
||||
|
||||
contextWindow Int
|
||||
maxOutputTokens Int?
|
||||
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive (DB constraint: 1-3)
|
||||
isEnabled Boolean @default(true)
|
||||
isRecommended Boolean @default(false)
|
||||
|
||||
// Model-specific capabilities
|
||||
// These vary per model even within the same provider (e.g., Hugging Face)
|
||||
// Default to false for safety - partially-seeded rows should not be assumed capable
|
||||
supportsTools Boolean @default(false)
|
||||
supportsJsonOutput Boolean @default(false)
|
||||
supportsReasoning Boolean @default(false)
|
||||
supportsParallelToolCalls Boolean @default(false)
|
||||
|
||||
capabilities Json @default("{}")
|
||||
metadata Json @default("{}")
|
||||
|
||||
Costs LlmModelCost[]
|
||||
SourceMigrations LlmModelMigration[] @relation("SourceMigrations")
|
||||
TargetMigrations LlmModelMigration[] @relation("TargetMigrations")
|
||||
|
||||
@@index([providerId, isEnabled])
|
||||
@@index([creatorId])
|
||||
// Note: slug already has @unique which creates an implicit index
|
||||
}
|
||||
|
||||
model LlmModelCost {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
unit LlmCostUnit @default(RUN)
|
||||
|
||||
creditCost Int // DB constraint: >= 0
|
||||
|
||||
// Provider identifier (e.g., "openai", "anthropic", "openrouter")
|
||||
// Used to determine which credential system provides the API key.
|
||||
// Allows different pricing for:
|
||||
// - Default provider costs (WHERE credentialId IS NULL)
|
||||
// - User's own API key costs (WHERE credentialId IS NOT NULL)
|
||||
credentialProvider String
|
||||
credentialId String?
|
||||
credentialType String?
|
||||
currency String?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
llmModelId String
|
||||
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
|
||||
|
||||
// Note: Unique constraints are implemented as partial indexes in migration SQL:
|
||||
// - One for default costs (WHERE credentialId IS NULL)
|
||||
// - One for credential-specific costs (WHERE credentialId IS NOT NULL)
|
||||
// This allows both provider-level defaults and credential-specific overrides
|
||||
}
|
||||
|
||||
model LlmModelCreator {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique // e.g., "openai", "anthropic", "meta"
|
||||
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
|
||||
description String?
|
||||
websiteUrl String? // Link to creator's website
|
||||
logoUrl String? // URL to creator's logo
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
|
||||
}
|
||||
|
||||
model LlmModelMigration {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
sourceModelSlug String // The original model that was disabled
|
||||
targetModelSlug String // The model workflows were migrated to
|
||||
reason String? // Why the migration happened (e.g., "Provider outage")
|
||||
|
||||
// FK constraints ensure slugs reference valid models
|
||||
SourceModel LlmModel @relation("SourceMigrations", fields: [sourceModelSlug], references: [slug], onDelete: Restrict)
|
||||
TargetModel LlmModel @relation("TargetMigrations", fields: [targetModelSlug], references: [slug], onDelete: Restrict)
|
||||
|
||||
// Track affected nodes as JSON array of node IDs
|
||||
// Format: ["node-uuid-1", "node-uuid-2", ...]
|
||||
migratedNodeIds Json @default("[]")
|
||||
nodeCount Int // Number of nodes migrated (DB constraint: >= 0)
|
||||
|
||||
// Custom pricing override for migrated workflows during the migration period.
|
||||
// Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
|
||||
// one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
|
||||
// to avoid billing surprises, or offer a discount during the transition.
|
||||
//
|
||||
// IMPORTANT: This field is intended for integration with the billing system.
|
||||
// When billing calculates costs for nodes affected by this migration, it should
|
||||
// check if customCreditCost is set and use it instead of the target model's cost.
|
||||
// If null, the target model's normal cost applies.
|
||||
//
|
||||
// TODO: Integrate with billing system to apply this override during cost calculation.
|
||||
// LIMITATION: This is a simple Int and doesn't distinguish RUN vs TOKENS pricing.
|
||||
// For token-priced models, this may be ambiguous. Consider migrating to a relation
|
||||
// with LlmModelCost or a dedicated override model in a follow-up PR.
|
||||
customCreditCost Int? // DB constraint: >= 0 when not null
|
||||
|
||||
// Revert tracking
|
||||
isReverted Boolean @default(false)
|
||||
revertedAt DateTime?
|
||||
|
||||
// Note: Partial unique index in migration SQL prevents multiple active migrations per source:
|
||||
// UNIQUE (sourceModelSlug) WHERE isReverted = false
|
||||
@@index([targetModelSlug])
|
||||
@@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"id": "test-agent-1",
|
||||
"graph_id": "test-agent-1",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
@@ -50,6 +51,7 @@
|
||||
"id": "test-agent-2",
|
||||
"graph_id": "test-agent-2",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
|
||||
@@ -84,27 +84,6 @@ class TestGmailReadBlock:
|
||||
assert "Hello World" in result
|
||||
assert "This is HTML content" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_html_fallback_when_html2text_conversion_fails(self):
|
||||
"""Fallback to raw HTML when html2text converter raises unexpectedly."""
|
||||
html_text = "<html><body><p>Broken <b>HTML</p></body></html>"
|
||||
|
||||
msg = {
|
||||
"id": "test_msg_html_error",
|
||||
"payload": {
|
||||
"mimeType": "text/html",
|
||||
"body": {"data": self._encode_base64(html_text)},
|
||||
},
|
||||
}
|
||||
|
||||
with patch("html2text.HTML2Text") as mock_html2text:
|
||||
mock_converter = Mock()
|
||||
mock_converter.handle.side_effect = ValueError("conversion failed")
|
||||
mock_html2text.return_value = mock_converter
|
||||
|
||||
result = await self.gmail_block._get_email_body(msg, self.mock_service)
|
||||
assert result == html_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_html_fallback_when_html2text_unavailable(self):
|
||||
"""Test fallback to raw HTML when html2text is not available."""
|
||||
|
||||
@@ -75,7 +75,7 @@ export const getSecondCalculatorNode = () => {
|
||||
export const getFormContainerSelector = (blockId: string): string | null => {
|
||||
const node = getNodeByBlockId(blockId);
|
||||
if (node) {
|
||||
return `[data-id="form-creator-container-${node.id}-node"]`;
|
||||
return `[data-id="form-creator-container-${node.id}"]`;
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
*
|
||||
* Typography (body, small, action, info, tip, warning) uses Tailwind utilities directly in steps.ts
|
||||
*/
|
||||
import "shepherd.js/dist/css/shepherd.css";
|
||||
import "./tutorial.css";
|
||||
|
||||
export const injectTutorialStyles = () => {
|
||||
|
||||
@@ -1,14 +1,3 @@
|
||||
.new-builder-tutorial-disable {
|
||||
opacity: 0.3 !important;
|
||||
pointer-events: none !important;
|
||||
filter: grayscale(100%) !important;
|
||||
}
|
||||
|
||||
.new-builder-tutorial-highlight {
|
||||
position: relative;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
.new-builder-tutorial-highlight * {
|
||||
opacity: 1 !important;
|
||||
filter: none !important;
|
||||
|
||||
@@ -15,8 +15,6 @@ import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||
import { DeleteChatDialog } from "./components/DeleteChatDialog/DeleteChatDialog";
|
||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
||||
import { NotificationBanner } from "./components/NotificationBanner/NotificationBanner";
|
||||
import { NotificationDialog } from "./components/NotificationDialog/NotificationDialog";
|
||||
import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
@@ -119,7 +117,6 @@ export function CopilotPage() {
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||
<NotificationBanner />
|
||||
{/* Drop overlay */}
|
||||
<div
|
||||
className={cn(
|
||||
@@ -204,7 +201,6 @@ export function CopilotPage() {
|
||||
onCancel={handleCancelDelete}
|
||||
/>
|
||||
)}
|
||||
<NotificationDialog />
|
||||
</SidebarProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { useCopilotUIStore } from "@/app/(platform)/copilot/store";
|
||||
import { ChangeEvent, FormEvent, useEffect, useState } from "react";
|
||||
|
||||
interface Args {
|
||||
@@ -17,16 +16,6 @@ export function useChatInput({
|
||||
}: Args) {
|
||||
const [value, setValue] = useState("");
|
||||
const [isSending, setIsSending] = useState(false);
|
||||
const { initialPrompt, setInitialPrompt } = useCopilotUIStore();
|
||||
|
||||
useEffect(
|
||||
function consumeInitialPrompt() {
|
||||
if (!initialPrompt) return;
|
||||
setValue((prev) => (prev.length === 0 ? initialPrompt : prev));
|
||||
setInitialPrompt(null);
|
||||
},
|
||||
[initialPrompt, setInitialPrompt],
|
||||
);
|
||||
|
||||
useEffect(
|
||||
function focusOnMount() {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { useMemo } from "react";
|
||||
import {
|
||||
Conversation,
|
||||
ConversationContent,
|
||||
@@ -9,7 +8,6 @@ import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner
|
||||
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
|
||||
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
|
||||
import { CopilotPendingReviews } from "../CopilotPendingReviews/CopilotPendingReviews";
|
||||
import {
|
||||
buildRenderSegments,
|
||||
getTurnMessages,
|
||||
@@ -53,50 +51,6 @@ function renderSegments(
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract graph_exec_id from tool outputs that need review.
|
||||
* Handles both:
|
||||
* - run_block ReviewRequiredResponse (has graph_exec_id directly)
|
||||
* - run_agent ExecutionStartedResponse with status "REVIEW" (has execution_id)
|
||||
*/
|
||||
function extractGraphExecId(
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[],
|
||||
): string | null {
|
||||
// Scan backwards — the most recent review output has the ID
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const msg = messages[i];
|
||||
for (const part of msg.parts) {
|
||||
if ("output" in part && part.output) {
|
||||
const out =
|
||||
typeof part.output === "string"
|
||||
? (() => {
|
||||
try {
|
||||
return JSON.parse(part.output);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})()
|
||||
: part.output;
|
||||
if (out && typeof out === "object") {
|
||||
// run_block: ReviewRequiredResponse has graph_exec_id
|
||||
if ("graph_exec_id" in out) {
|
||||
return (out as { graph_exec_id: string }).graph_exec_id;
|
||||
}
|
||||
// run_agent: ExecutionStartedResponse with status "REVIEW"
|
||||
if (
|
||||
"execution_id" in out &&
|
||||
"status" in out &&
|
||||
(out as { status: string }).status === "REVIEW"
|
||||
) {
|
||||
return (out as { execution_id: string }).execution_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function ChatMessagesContainer({
|
||||
messages,
|
||||
status,
|
||||
@@ -106,7 +60,6 @@ export function ChatMessagesContainer({
|
||||
sessionID,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
const graphExecId = useMemo(() => extractGraphExecId(messages), [messages]);
|
||||
|
||||
const hasInflight = (() => {
|
||||
if (lastMessage?.role !== "assistant") return false;
|
||||
@@ -252,7 +205,6 @@ export function ChatMessagesContainer({
|
||||
</MessageContent>
|
||||
</Message>
|
||||
)}
|
||||
{graphExecId && <CopilotPendingReviews graphExecId={graphExecId} />}
|
||||
{error && (
|
||||
<details className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
|
||||
<summary className="cursor-pointer font-medium">
|
||||
|
||||
@@ -130,7 +130,6 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
case "tool-get_doc_page":
|
||||
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_block":
|
||||
case "tool-continue_run_block":
|
||||
return <RunBlockTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_mcp_tool":
|
||||
return <RunMCPToolComponent key={key} part={part as ToolUIPart} />;
|
||||
|
||||
@@ -19,7 +19,6 @@ const CUSTOM_TOOL_TYPES = new Set([
|
||||
"tool-search_docs",
|
||||
"tool-get_doc_page",
|
||||
"tool-run_block",
|
||||
"tool-continue_run_block",
|
||||
"tool-run_mcp_tool",
|
||||
"tool-run_agent",
|
||||
"tool-schedule_agent",
|
||||
@@ -34,7 +33,6 @@ const INTERACTIVE_RESPONSE_TYPES: ReadonlySet<string> = new Set([
|
||||
ResponseType.setup_requirements,
|
||||
ResponseType.agent_details,
|
||||
ResponseType.block_details,
|
||||
ResponseType.review_required,
|
||||
ResponseType.need_login,
|
||||
ResponseType.input_validation_error,
|
||||
ResponseType.agent_builder_clarification_needed,
|
||||
|
||||
@@ -23,36 +23,24 @@ import {
|
||||
useSidebar,
|
||||
} from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CheckCircle,
|
||||
DotsThree,
|
||||
PlusCircleIcon,
|
||||
PlusIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
|
||||
export function ChatSidebar() {
|
||||
const { state } = useSidebar();
|
||||
const isCollapsed = state === "collapsed";
|
||||
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
|
||||
const {
|
||||
sessionToDelete,
|
||||
setSessionToDelete,
|
||||
completedSessionIDs,
|
||||
clearCompletedSession,
|
||||
} = useCopilotUIStore();
|
||||
const { sessionToDelete, setSessionToDelete } = useCopilotUIStore();
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { data: sessionsResponse, isLoading: isLoadingSessions } =
|
||||
useGetV2ListSessions({ limit: 50 }, { query: { refetchInterval: 10_000 } });
|
||||
useGetV2ListSessions({ limit: 50 });
|
||||
|
||||
const { mutate: deleteSession, isPending: isDeleting } =
|
||||
useDeleteV2DeleteSession({
|
||||
@@ -111,22 +99,6 @@ export function ChatSidebar() {
|
||||
}
|
||||
}, [editingSessionId]);
|
||||
|
||||
// Refetch session list when active session changes
|
||||
useEffect(() => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
}, [sessionId, queryClient]);
|
||||
|
||||
// Clear completed indicator when navigating to a session (works for all paths)
|
||||
useEffect(() => {
|
||||
if (!sessionId || !completedSessionIDs.has(sessionId)) return;
|
||||
clearCompletedSession(sessionId);
|
||||
const remaining = completedSessionIDs.size - 1;
|
||||
document.title =
|
||||
remaining > 0 ? `(${remaining}) Otto is ready - AutoGPT` : "AutoGPT";
|
||||
}, [sessionId, completedSessionIDs, clearCompletedSession]);
|
||||
|
||||
const sessions =
|
||||
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
|
||||
|
||||
@@ -256,11 +228,8 @@ export function ChatSidebar() {
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="relative left-5 flex items-center gap-1">
|
||||
<NotificationToggle />
|
||||
<div className="relative left-1">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
<div className="relative left-6">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</div>
|
||||
{sessionId ? (
|
||||
@@ -336,8 +305,8 @@ export function ChatSidebar() {
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full items-center gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
@@ -360,22 +329,10 @@ export function ChatSidebar() {
|
||||
</motion.span>
|
||||
</AnimatePresence>
|
||||
</Text>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
{session.is_processing &&
|
||||
session.id !== sessionId &&
|
||||
!completedSessionIDs.has(session.id) && (
|
||||
<PulseLoader size={16} className="shrink-0" />
|
||||
)}
|
||||
{completedSessionIDs.has(session.id) &&
|
||||
session.id !== sessionId && (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
)}
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
)}
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Switch } from "@/components/atoms/Switch/Switch";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
|
||||
import { useCopilotUIStore } from "../../../../store";
|
||||
|
||||
export function NotificationToggle() {
|
||||
const {
|
||||
isNotificationsEnabled,
|
||||
setNotificationsEnabled,
|
||||
isSoundEnabled,
|
||||
toggleSound,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
async function handleToggleNotifications() {
|
||||
if (isNotificationsEnabled) {
|
||||
setNotificationsEnabled(false);
|
||||
return;
|
||||
}
|
||||
if (typeof Notification === "undefined") {
|
||||
toast({
|
||||
title: "Notifications not supported",
|
||||
description: "Your browser does not support notifications.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
const permission = await Notification.requestPermission();
|
||||
if (permission === "granted") {
|
||||
setNotificationsEnabled(true);
|
||||
} else {
|
||||
toast({
|
||||
title: "Notifications blocked",
|
||||
description:
|
||||
"Please allow notifications in your browser settings to enable this feature.",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<button
|
||||
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
|
||||
aria-label="Notification settings"
|
||||
>
|
||||
{!isNotificationsEnabled ? (
|
||||
<BellSlash className="!size-5" />
|
||||
) : isSoundEnabled ? (
|
||||
<BellRinging className="!size-5" />
|
||||
) : (
|
||||
<Bell className="!size-5" />
|
||||
)}
|
||||
</button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-56 p-3">
|
||||
<div className="flex flex-col gap-3">
|
||||
<label className="flex items-center justify-between">
|
||||
<span className="text-sm text-zinc-700">Notifications</span>
|
||||
<Switch
|
||||
checked={isNotificationsEnabled}
|
||||
onCheckedChange={handleToggleNotifications}
|
||||
/>
|
||||
</label>
|
||||
<label className="flex items-center justify-between">
|
||||
<span
|
||||
className={cn(
|
||||
"text-sm text-zinc-700",
|
||||
!isNotificationsEnabled && "opacity-50",
|
||||
)}
|
||||
>
|
||||
Sound
|
||||
</span>
|
||||
<Switch
|
||||
checked={isSoundEnabled && isNotificationsEnabled}
|
||||
onCheckedChange={toggleSound}
|
||||
disabled={!isNotificationsEnabled}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback } from "react";
|
||||
import { PendingReviewsList } from "@/components/organisms/PendingReviewsList/PendingReviewsList";
|
||||
import { useCopilotChatActions } from "../CopilotChatActionsProvider/useCopilotChatActions";
|
||||
import { usePendingReviewsForExecution } from "@/hooks/usePendingReviews";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
|
||||
interface Props {
|
||||
graphExecId: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Renders a single consolidated PendingReviewsList for all pending copilot
|
||||
* reviews in a session — mirrors the non-copilot review page behavior.
|
||||
* Works for both run_block (synthetic copilot-session-*) and run_agent (real graph exec) reviews.
|
||||
*/
|
||||
export function CopilotPendingReviews({ graphExecId }: Props) {
|
||||
const { onSend } = useCopilotChatActions();
|
||||
const { pendingReviews, refetch } = usePendingReviewsForExecution(
|
||||
graphExecId,
|
||||
{ enabled: !!graphExecId, refetchInterval: 2000 },
|
||||
);
|
||||
|
||||
// Graph executions auto-resume after approval; block reviews need continue_run_block.
|
||||
const isGraphExecution = !graphExecId.startsWith("copilot-session-");
|
||||
|
||||
const handleReviewComplete = useCallback(async () => {
|
||||
// Brief delay for the server to propagate the approval
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
const result = await refetch();
|
||||
const remaining = okData(result.data) || [];
|
||||
|
||||
if (remaining.length > 0) return;
|
||||
|
||||
if (isGraphExecution) {
|
||||
onSend(
|
||||
`All pending reviews have been processed. ` +
|
||||
`The agent execution will resume automatically for approved reviews. ` +
|
||||
`Use view_agent_output with execution_id="${graphExecId}" to check the result.`,
|
||||
);
|
||||
} else {
|
||||
onSend(
|
||||
`All pending reviews have been processed. ` +
|
||||
`For any approved reviews, call continue_run_block with the corresponding review_id to execute them. ` +
|
||||
`For rejected reviews, no further action is needed.`,
|
||||
);
|
||||
}
|
||||
}, [refetch, onSend, isGraphExecution, graphExecId]);
|
||||
|
||||
if (pendingReviews.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<PendingReviewsList
|
||||
reviews={pendingReviews}
|
||||
onReviewComplete={handleReviewComplete}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -3,17 +3,8 @@ import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CheckCircle,
|
||||
PlusIcon,
|
||||
SpeakerHigh,
|
||||
SpeakerSlash,
|
||||
SpinnerGapIcon,
|
||||
X,
|
||||
} from "@phosphor-icons/react";
|
||||
import { PlusIcon, SpinnerGapIcon, X } from "@phosphor-icons/react";
|
||||
import { Drawer } from "vaul";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
@@ -61,13 +52,6 @@ export function MobileDrawer({
|
||||
onClose,
|
||||
onOpenChange,
|
||||
}: Props) {
|
||||
const {
|
||||
completedSessionIDs,
|
||||
clearCompletedSession,
|
||||
isSoundEnabled,
|
||||
toggleSound,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
return (
|
||||
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
|
||||
<Drawer.Portal>
|
||||
@@ -78,31 +62,14 @@ export function MobileDrawer({
|
||||
<Drawer.Title className="text-lg font-semibold text-zinc-800">
|
||||
Your chats
|
||||
</Drawer.Title>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
onClick={toggleSound}
|
||||
className="rounded p-1.5 text-zinc-400 transition-colors hover:text-zinc-600"
|
||||
aria-label={
|
||||
isSoundEnabled
|
||||
? "Disable notification sound"
|
||||
: "Enable notification sound"
|
||||
}
|
||||
>
|
||||
{isSoundEnabled ? (
|
||||
<SpeakerHigh className="h-4 w-4" />
|
||||
) : (
|
||||
<SpeakerSlash className="h-4 w-4" />
|
||||
)}
|
||||
</button>
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Close sessions"
|
||||
onClick={onClose}
|
||||
>
|
||||
<X width="1rem" height="1rem" />
|
||||
</Button>
|
||||
</div>
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Close sessions"
|
||||
onClick={onClose}
|
||||
>
|
||||
<X width="1rem" height="1rem" />
|
||||
</Button>
|
||||
</div>
|
||||
{currentSessionId ? (
|
||||
<div className="mt-2">
|
||||
@@ -136,12 +103,7 @@ export function MobileDrawer({
|
||||
sessions.map((session) => (
|
||||
<button
|
||||
key={session.id}
|
||||
onClick={() => {
|
||||
onSelectSession(session.id);
|
||||
if (completedSessionIDs.has(session.id)) {
|
||||
clearCompletedSession(session.id);
|
||||
}
|
||||
}}
|
||||
onClick={() => onSelectSession(session.id)}
|
||||
className={cn(
|
||||
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
||||
session.id === currentSessionId
|
||||
@@ -150,7 +112,7 @@ export function MobileDrawer({
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="flex min-w-0 max-w-full items-center gap-1.5">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
@@ -162,18 +124,6 @@ export function MobileDrawer({
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</Text>
|
||||
{session.is_processing &&
|
||||
!completedSessionIDs.has(session.id) &&
|
||||
session.id !== currentSessionId && (
|
||||
<PulseLoader size={8} className="shrink-0" />
|
||||
)}
|
||||
{completedSessionIDs.has(session.id) &&
|
||||
session.id !== currentSessionId && (
|
||||
<CheckCircle
|
||||
className="h-4 w-4 shrink-0 text-green-500"
|
||||
weight="fill"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { BellRinging, X } from "@phosphor-icons/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
|
||||
export function NotificationBanner() {
|
||||
const { setNotificationsEnabled, isNotificationsEnabled } =
|
||||
useCopilotUIStore();
|
||||
|
||||
const [dismissed, setDismissed] = useState(
|
||||
() => storage.get(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED) === "true",
|
||||
);
|
||||
const [permission, setPermission] = useState(() =>
|
||||
typeof Notification !== "undefined" ? Notification.permission : "denied",
|
||||
);
|
||||
|
||||
// Re-read dismissed flag when notifications are toggled off (e.g. clearCopilotLocalData)
|
||||
useEffect(() => {
|
||||
if (!isNotificationsEnabled) {
|
||||
setDismissed(
|
||||
storage.get(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED) === "true",
|
||||
);
|
||||
}
|
||||
}, [isNotificationsEnabled]);
|
||||
|
||||
// Don't show if notifications aren't supported, already decided, dismissed, or already enabled
|
||||
if (
|
||||
typeof Notification === "undefined" ||
|
||||
permission !== "default" ||
|
||||
dismissed ||
|
||||
isNotificationsEnabled
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
function handleEnable() {
|
||||
Notification.requestPermission().then((result) => {
|
||||
setPermission(result);
|
||||
if (result === "granted") {
|
||||
setNotificationsEnabled(true);
|
||||
handleDismiss();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function handleDismiss() {
|
||||
storage.set(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED, "true");
|
||||
setDismissed(true);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-3 border-b border-amber-200 bg-amber-50 px-4 py-2.5">
|
||||
<BellRinging className="h-5 w-5 shrink-0 text-amber-600" weight="fill" />
|
||||
<Text variant="body" className="flex-1 text-sm text-amber-800">
|
||||
Enable browser notifications to know when Otto finishes working, even
|
||||
when you switch tabs.
|
||||
</Text>
|
||||
<Button variant="primary" size="small" onClick={handleEnable}>
|
||||
Enable
|
||||
</Button>
|
||||
<button
|
||||
onClick={handleDismiss}
|
||||
className="rounded p-1 text-amber-400 transition-colors hover:text-amber-600"
|
||||
aria-label="Dismiss"
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { BellRinging } from "@phosphor-icons/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
|
||||
export function NotificationDialog() {
|
||||
const {
|
||||
showNotificationDialog,
|
||||
setShowNotificationDialog,
|
||||
setNotificationsEnabled,
|
||||
isNotificationsEnabled,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
const [dismissed, setDismissed] = useState(
|
||||
() => storage.get(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED) === "true",
|
||||
);
|
||||
const [permission, setPermission] = useState(() =>
|
||||
typeof Notification !== "undefined" ? Notification.permission : "denied",
|
||||
);
|
||||
|
||||
// Re-read dismissed flag when notifications are toggled off (e.g. clearCopilotLocalData)
|
||||
useEffect(() => {
|
||||
if (!isNotificationsEnabled) {
|
||||
setDismissed(
|
||||
storage.get(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED) === "true",
|
||||
);
|
||||
}
|
||||
}, [isNotificationsEnabled]);
|
||||
|
||||
const shouldShowAuto =
|
||||
typeof Notification !== "undefined" &&
|
||||
permission === "default" &&
|
||||
!dismissed;
|
||||
|
||||
const isOpen = showNotificationDialog || shouldShowAuto;
|
||||
|
||||
function handleEnable() {
|
||||
if (typeof Notification === "undefined") {
|
||||
handleDismiss();
|
||||
return;
|
||||
}
|
||||
Notification.requestPermission().then((result) => {
|
||||
setPermission(result);
|
||||
if (result === "granted") {
|
||||
setNotificationsEnabled(true);
|
||||
handleDismiss();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function handleDismiss() {
|
||||
storage.set(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED, "true");
|
||||
setDismissed(true);
|
||||
setShowNotificationDialog(false);
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="Stay in the loop"
|
||||
styling={{ maxWidth: "28rem", minWidth: "auto" }}
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: async (open) => {
|
||||
if (!open) handleDismiss();
|
||||
},
|
||||
}}
|
||||
onClose={handleDismiss}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col items-center gap-4 py-2">
|
||||
<div className="flex h-12 w-12 items-center justify-center rounded-full bg-violet-100">
|
||||
<BellRinging className="h-6 w-6 text-violet-600" weight="fill" />
|
||||
</div>
|
||||
<Text variant="body" className="text-center text-neutral-600">
|
||||
Otto can notify you when a response is ready, even if you switch
|
||||
tabs or close this page. Enable notifications so you never miss one.
|
||||
</Text>
|
||||
</div>
|
||||
<Dialog.Footer>
|
||||
<Button variant="secondary" onClick={handleDismiss}>
|
||||
Not now
|
||||
</Button>
|
||||
<Button variant="primary" onClick={handleEnable}>
|
||||
Enable notifications
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -15,8 +15,6 @@
|
||||
position: absolute;
|
||||
left: 0;
|
||||
top: 0;
|
||||
transform: scale(0);
|
||||
opacity: 0;
|
||||
animation: ripple 2s linear infinite;
|
||||
}
|
||||
|
||||
@@ -27,10 +25,7 @@
|
||||
@keyframes ripple {
|
||||
0% {
|
||||
transform: scale(0);
|
||||
opacity: 0.6;
|
||||
}
|
||||
50% {
|
||||
opacity: 0.3;
|
||||
opacity: 1;
|
||||
}
|
||||
100% {
|
||||
transform: scale(1);
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { create } from "zustand";
|
||||
|
||||
export interface DeleteTarget {
|
||||
@@ -7,89 +6,17 @@ export interface DeleteTarget {
|
||||
}
|
||||
|
||||
interface CopilotUIState {
|
||||
/** Prompt extracted from URL hash (e.g. /copilot#prompt=...) for input prefill. */
|
||||
initialPrompt: string | null;
|
||||
setInitialPrompt: (prompt: string | null) => void;
|
||||
|
||||
sessionToDelete: DeleteTarget | null;
|
||||
setSessionToDelete: (target: DeleteTarget | null) => void;
|
||||
|
||||
isDrawerOpen: boolean;
|
||||
setDrawerOpen: (open: boolean) => void;
|
||||
|
||||
completedSessionIDs: Set<string>;
|
||||
addCompletedSession: (id: string) => void;
|
||||
clearCompletedSession: (id: string) => void;
|
||||
clearAllCompletedSessions: () => void;
|
||||
|
||||
isNotificationsEnabled: boolean;
|
||||
setNotificationsEnabled: (enabled: boolean) => void;
|
||||
|
||||
isSoundEnabled: boolean;
|
||||
toggleSound: () => void;
|
||||
|
||||
showNotificationDialog: boolean;
|
||||
setShowNotificationDialog: (show: boolean) => void;
|
||||
|
||||
clearCopilotLocalData: () => void;
|
||||
}
|
||||
|
||||
export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
initialPrompt: null,
|
||||
setInitialPrompt: (prompt) => set({ initialPrompt: prompt }),
|
||||
|
||||
sessionToDelete: null,
|
||||
setSessionToDelete: (target) => set({ sessionToDelete: target }),
|
||||
|
||||
isDrawerOpen: false,
|
||||
setDrawerOpen: (open) => set({ isDrawerOpen: open }),
|
||||
|
||||
completedSessionIDs: new Set<string>(),
|
||||
addCompletedSession: (id) =>
|
||||
set((state) => {
|
||||
const next = new Set(state.completedSessionIDs);
|
||||
next.add(id);
|
||||
return { completedSessionIDs: next };
|
||||
}),
|
||||
clearCompletedSession: (id) =>
|
||||
set((state) => {
|
||||
const next = new Set(state.completedSessionIDs);
|
||||
next.delete(id);
|
||||
return { completedSessionIDs: next };
|
||||
}),
|
||||
clearAllCompletedSessions: () =>
|
||||
set({ completedSessionIDs: new Set<string>() }),
|
||||
|
||||
isNotificationsEnabled:
|
||||
storage.get(Key.COPILOT_NOTIFICATIONS_ENABLED) === "true" &&
|
||||
typeof Notification !== "undefined" &&
|
||||
Notification.permission === "granted",
|
||||
setNotificationsEnabled: (enabled) => {
|
||||
storage.set(Key.COPILOT_NOTIFICATIONS_ENABLED, String(enabled));
|
||||
set({ isNotificationsEnabled: enabled });
|
||||
},
|
||||
|
||||
isSoundEnabled: storage.get(Key.COPILOT_SOUND_ENABLED) !== "false",
|
||||
toggleSound: () =>
|
||||
set((state) => {
|
||||
const next = !state.isSoundEnabled;
|
||||
storage.set(Key.COPILOT_SOUND_ENABLED, String(next));
|
||||
return { isSoundEnabled: next };
|
||||
}),
|
||||
|
||||
showNotificationDialog: false,
|
||||
setShowNotificationDialog: (show) => set({ showNotificationDialog: show }),
|
||||
|
||||
clearCopilotLocalData: () => {
|
||||
storage.clean(Key.COPILOT_NOTIFICATIONS_ENABLED);
|
||||
storage.clean(Key.COPILOT_SOUND_ENABLED);
|
||||
storage.clean(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED);
|
||||
storage.clean(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED);
|
||||
set({
|
||||
completedSessionIDs: new Set<string>(),
|
||||
isNotificationsEnabled: false,
|
||||
isSoundEnabled: true,
|
||||
});
|
||||
document.title = "AutoGPT";
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -15,7 +15,6 @@ import {
|
||||
isRunBlockBlockOutput,
|
||||
isRunBlockDetailsOutput,
|
||||
isRunBlockErrorOutput,
|
||||
isRunBlockReviewRequiredOutput,
|
||||
isRunBlockSetupRequirementsOutput,
|
||||
ToolIcon,
|
||||
} from "./helpers";
|
||||
@@ -55,15 +54,10 @@ export function RunBlockTool({ part }: Props) {
|
||||
part.state === "output-available" &&
|
||||
!!output &&
|
||||
!setupRequirementsOutput &&
|
||||
!isRunBlockReviewRequiredOutput(output) &&
|
||||
(isRunBlockBlockOutput(output) ||
|
||||
isRunBlockDetailsOutput(output) ||
|
||||
isRunBlockErrorOutput(output));
|
||||
|
||||
// Review UI is rendered at the chat level by CopilotPendingReviews,
|
||||
// not inside each tool card. This matches the non-copilot flow where
|
||||
// a single PendingReviewsList shows all reviews grouped together.
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
|
||||
@@ -26,18 +26,6 @@ export interface BlockDetailsResponse {
|
||||
user_authenticated: boolean;
|
||||
}
|
||||
|
||||
/** Response when a block requires human review before execution. */
|
||||
export interface ReviewRequiredResponse {
|
||||
type: typeof ResponseType.review_required;
|
||||
message: string;
|
||||
session_id?: string | null;
|
||||
block_id: string;
|
||||
block_name: string;
|
||||
review_id: string;
|
||||
graph_exec_id: string;
|
||||
input_data: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface RunBlockInput {
|
||||
block_id?: string;
|
||||
block_name?: string;
|
||||
@@ -48,14 +36,12 @@ export type RunBlockToolOutput =
|
||||
| SetupRequirementsResponse
|
||||
| BlockDetailsResponse
|
||||
| BlockOutputResponse
|
||||
| ReviewRequiredResponse
|
||||
| ErrorResponse;
|
||||
|
||||
const RUN_BLOCK_OUTPUT_TYPES = new Set<string>([
|
||||
ResponseType.setup_requirements,
|
||||
ResponseType.block_details,
|
||||
ResponseType.block_output,
|
||||
ResponseType.review_required,
|
||||
ResponseType.error,
|
||||
]);
|
||||
|
||||
@@ -80,19 +66,7 @@ export function isRunBlockDetailsOutput(
|
||||
export function isRunBlockBlockOutput(
|
||||
output: RunBlockToolOutput,
|
||||
): output is BlockOutputResponse {
|
||||
return (
|
||||
output.type === ResponseType.block_output ||
|
||||
("block_id" in output && !("review_id" in output))
|
||||
);
|
||||
}
|
||||
|
||||
export function isRunBlockReviewRequiredOutput(
|
||||
output: RunBlockToolOutput,
|
||||
): output is ReviewRequiredResponse {
|
||||
return (
|
||||
output.type === ResponseType.review_required ||
|
||||
("review_id" in output && "block_name" in output && "input_data" in output)
|
||||
);
|
||||
return output.type === ResponseType.block_output || "block_id" in output;
|
||||
}
|
||||
|
||||
export function isRunBlockErrorOutput(
|
||||
@@ -117,7 +91,6 @@ function parseOutput(output: unknown): RunBlockToolOutput | null {
|
||||
if (typeof type === "string" && RUN_BLOCK_OUTPUT_TYPES.has(type)) {
|
||||
return output as RunBlockToolOutput;
|
||||
}
|
||||
if ("review_id" in output) return output as ReviewRequiredResponse;
|
||||
if ("block_id" in output) return output as BlockOutputResponse;
|
||||
if ("block" in output) return output as BlockDetailsResponse;
|
||||
if ("setup_info" in output) return output as SetupRequirementsResponse;
|
||||
@@ -162,9 +135,6 @@ export function getAnimationText(part: {
|
||||
if (isRunBlockSetupRequirementsOutput(output)) {
|
||||
return `Setup needed for "${output.setup_info.agent_name}"`;
|
||||
}
|
||||
if (isRunBlockReviewRequiredOutput(output)) {
|
||||
return `Review needed for "${output.block_name}"`;
|
||||
}
|
||||
return "Error running block";
|
||||
}
|
||||
case "output-error":
|
||||
@@ -257,14 +227,6 @@ export function getAccordionMeta(output: RunBlockToolOutput): {
|
||||
};
|
||||
}
|
||||
|
||||
if (isRunBlockReviewRequiredOutput(output)) {
|
||||
return {
|
||||
icon,
|
||||
title: output.block_name,
|
||||
description: "Sensitive action — awaiting review",
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
icon: (
|
||||
<WarningDiamondIcon size={32} weight="light" className="text-red-500" />
|
||||
|
||||
@@ -235,8 +235,8 @@ describe("getAnimationText", () => {
|
||||
state: "input-streaming",
|
||||
...BASE,
|
||||
});
|
||||
expect(text).toContain("Connecting");
|
||||
expect(text).toContain("example.com");
|
||||
expect(text).toContain("Discovering");
|
||||
expect(text).toContain("mcp.example.com");
|
||||
});
|
||||
|
||||
it("shows tool call text when tool_name is set", () => {
|
||||
@@ -245,7 +245,7 @@ describe("getAnimationText", () => {
|
||||
input: { server_url: "https://mcp.example.com/mcp", tool_name: "fetch" },
|
||||
});
|
||||
expect(text).toContain("fetch");
|
||||
expect(text).toContain("example.com");
|
||||
expect(text).toContain("mcp.example.com");
|
||||
});
|
||||
|
||||
it("includes query argument preview when tool_arguments has a query key", () => {
|
||||
@@ -282,7 +282,7 @@ describe("getAnimationText", () => {
|
||||
tool_arguments: {},
|
||||
},
|
||||
});
|
||||
expect(text).toBe("Calling list_users on example.com");
|
||||
expect(text).toBe("Calling list_users on mcp.example.com");
|
||||
});
|
||||
|
||||
it("truncates long argument previews to 60 chars with ellipsis", () => {
|
||||
@@ -327,8 +327,8 @@ describe("getAnimationText", () => {
|
||||
output: DISCOVERY,
|
||||
input: { server_url: "https://mcp.example.com/mcp" },
|
||||
});
|
||||
expect(text).toContain("Connected");
|
||||
expect(text).toContain("example.com");
|
||||
expect(text).toContain("Discovered");
|
||||
expect(text).toContain("1");
|
||||
});
|
||||
|
||||
it("shows setup label on output-available for setup requirements", () => {
|
||||
|
||||
@@ -12,6 +12,8 @@ import { CredentialsProvidersContext } from "@/providers/agent-credentials/crede
|
||||
import { useContext, useEffect, useRef, useState } from "react";
|
||||
import { useCopilotChatActions } from "../../../../components/CopilotChatActionsProvider/useCopilotChatActions";
|
||||
import { ContentMessage } from "../../../../components/ToolAccordion/AccordionContent";
|
||||
import { serverHost } from "../../helpers";
|
||||
|
||||
interface Props {
|
||||
output: SetupRequirementsResponse;
|
||||
/**
|
||||
@@ -36,8 +38,7 @@ export function MCPSetupCard({ output, retryInstruction }: Props) {
|
||||
|
||||
// setup_info.agent_id is set to the server_url in the backend
|
||||
const serverUrl = output.setup_info.agent_id;
|
||||
// agent_name is computed by the backend as the display name for the service
|
||||
const service = output.setup_info.agent_name;
|
||||
const host = serverHost(serverUrl);
|
||||
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
@@ -94,7 +95,10 @@ export function MCPSetupCard({ output, retryInstruction }: Props) {
|
||||
}
|
||||
|
||||
setConnected(true);
|
||||
onSend(retryInstruction ?? "I've connected. Please retry.");
|
||||
onSend(
|
||||
retryInstruction ??
|
||||
"I've connected the MCP server credentials. Please retry.",
|
||||
);
|
||||
} catch (e: unknown) {
|
||||
const err = e as Record<string, unknown>;
|
||||
if (err?.status === 400) {
|
||||
@@ -133,7 +137,10 @@ export function MCPSetupCard({ output, retryInstruction }: Props) {
|
||||
if (!(res.status >= 200 && res.status < 300))
|
||||
throw new Error("Failed to store token");
|
||||
setConnected(true);
|
||||
onSend(retryInstruction ?? "I've connected. Please retry.");
|
||||
onSend(
|
||||
retryInstruction ??
|
||||
"I've connected the MCP server credentials. Please retry.",
|
||||
);
|
||||
} catch (e: unknown) {
|
||||
const err = e as Record<string, unknown>;
|
||||
setError(
|
||||
@@ -148,7 +155,7 @@ export function MCPSetupCard({ output, retryInstruction }: Props) {
|
||||
if (connected) {
|
||||
return (
|
||||
<div className="mt-2 rounded-lg border border-green-200 bg-green-50 px-3 py-2 text-sm text-green-700">
|
||||
Connected to {service}!
|
||||
Connected to {host}!
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -164,7 +171,7 @@ export function MCPSetupCard({ output, retryInstruction }: Props) {
|
||||
onClick={handleConnect}
|
||||
disabled={loading}
|
||||
>
|
||||
{loading ? "Connecting…" : `Connect ${service}`}
|
||||
{loading ? "Connecting…" : `Connect to ${host}`}
|
||||
</Button>
|
||||
|
||||
{error && (
|
||||
@@ -177,7 +184,7 @@ export function MCPSetupCard({ output, retryInstruction }: Props) {
|
||||
<div className="mt-3 flex gap-2">
|
||||
<input
|
||||
type="password"
|
||||
aria-label={`API token for ${service}`}
|
||||
aria-label={`API token for ${host}`}
|
||||
placeholder="Paste API token"
|
||||
value={manualToken}
|
||||
onChange={(e) => setManualToken(e.target.value)}
|
||||
|
||||
@@ -32,11 +32,11 @@ vi.mock("@/app/api/__generated__/endpoints/mcp/mcp", () => ({
|
||||
function makeSetupOutput(serverUrl = "https://mcp.example.com/mcp") {
|
||||
return {
|
||||
type: "setup_requirements" as const,
|
||||
message: "To continue, sign in to example.com and approve access.",
|
||||
message: "The MCP server at mcp.example.com requires authentication.",
|
||||
session_id: "test-session",
|
||||
setup_info: {
|
||||
agent_id: serverUrl,
|
||||
agent_name: "example.com",
|
||||
agent_name: "MCP: mcp.example.com",
|
||||
user_readiness: {
|
||||
has_all_credentials: false,
|
||||
missing_credentials: {},
|
||||
@@ -58,9 +58,9 @@ describe("MCPSetupCard", () => {
|
||||
|
||||
it("renders setup message and connect button", () => {
|
||||
render(<MCPSetupCard output={makeSetupOutput()} />);
|
||||
expect(screen.getByText(/sign in to example\.com/i)).toBeDefined();
|
||||
expect(screen.getByText(/requires authentication/)).toBeDefined();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /connect example\.com/i }),
|
||||
screen.getByRole("button", { name: /connect to mcp.example.com/i }),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
@@ -76,7 +76,7 @@ describe("MCPSetupCard", () => {
|
||||
|
||||
render(<MCPSetupCard output={makeSetupOutput()} />);
|
||||
fireEvent.click(
|
||||
screen.getByRole("button", { name: /connect example\.com/i }),
|
||||
screen.getByRole("button", { name: /connect to mcp.example.com/i }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
@@ -100,7 +100,7 @@ describe("MCPSetupCard", () => {
|
||||
|
||||
render(<MCPSetupCard output={makeSetupOutput()} />);
|
||||
fireEvent.click(
|
||||
screen.getByRole("button", { name: /connect example\.com/i }),
|
||||
screen.getByRole("button", { name: /connect to mcp.example.com/i }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
@@ -127,7 +127,7 @@ describe("MCPSetupCard", () => {
|
||||
fireEvent.click(screen.getByRole("button", { name: /use token/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/connected to example\.com/i)).toBeDefined();
|
||||
expect(screen.getByText(/connected to mcp.example.com/i)).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -125,11 +125,6 @@ export function serverHost(url: string): string {
|
||||
}
|
||||
}
|
||||
|
||||
/** Strip the 'mcp.' prefix from an MCP hostname: 'mcp.sentry.dev' → 'sentry.dev' */
|
||||
export function serviceNameFromHost(host: string): string {
|
||||
return host.startsWith("mcp.") ? host.slice(4) : host;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a short preview of the most meaningful argument value, e.g. `"my query"`.
|
||||
* Checks common "query" key names first, then falls back to the first string value.
|
||||
@@ -179,30 +174,28 @@ export function getAnimationText(part: {
|
||||
const host = input?.server_url ? serverHost(input.server_url) : "";
|
||||
const toolName = input?.tool_name?.trim();
|
||||
|
||||
const service = host ? serviceNameFromHost(host) : "";
|
||||
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available": {
|
||||
if (!toolName) return `Connecting to ${service || "integration"}…`;
|
||||
if (!toolName) return `Discovering MCP tools${host ? ` on ${host}` : ""}`;
|
||||
const argPreview = getArgPreview(input?.tool_arguments);
|
||||
return `Calling ${toolName}${argPreview ? `(${argPreview})` : ""}${service ? ` on ${service}` : ""}`;
|
||||
return `Calling ${toolName}${argPreview ? `(${argPreview})` : ""}${host ? ` on ${host}` : ""}`;
|
||||
}
|
||||
case "output-available": {
|
||||
const output = getRunMCPToolOutput(part);
|
||||
if (!output) return "Connecting…";
|
||||
if (!output) return "Connecting to MCP server";
|
||||
if (isSetupRequirementsOutput(output))
|
||||
return `Connect ${output.setup_info.agent_name}`;
|
||||
return `Connect to ${output.setup_info.agent_name}`;
|
||||
if (isMCPToolOutput(output))
|
||||
return `Ran ${output.tool_name}${service ? ` on ${service}` : ""}`;
|
||||
return `Ran ${output.tool_name}${host ? ` on ${host}` : ""}`;
|
||||
if (isDiscoveryOutput(output))
|
||||
return `Connected to ${serviceNameFromHost(serverHost(output.server_url))}`;
|
||||
return "Connection error";
|
||||
return `Discovered ${output.tools.length} tool(s) on ${serverHost(output.server_url)}`;
|
||||
return "MCP error";
|
||||
}
|
||||
case "output-error":
|
||||
return "Connection error";
|
||||
return "MCP error";
|
||||
default:
|
||||
return "Connecting…";
|
||||
return "Connecting to MCP server";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import type { WebSocketNotification } from "@/lib/autogpt-server-api/types";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { useCopilotUIStore } from "./store";
|
||||
|
||||
const ORIGINAL_TITLE = "AutoGPT";
|
||||
const NOTIFICATION_SOUND_PATH = "/sounds/notification.mp3";
|
||||
|
||||
/**
|
||||
* Listens for copilot completion notifications via WebSocket.
|
||||
* Updates the Zustand store, plays a sound, and updates document.title.
|
||||
*/
|
||||
export function useCopilotNotifications(activeSessionID: string | null) {
|
||||
const api = useBackendAPI();
|
||||
const audioRef = useRef<HTMLAudioElement | null>(null);
|
||||
const activeSessionRef = useRef(activeSessionID);
|
||||
activeSessionRef.current = activeSessionID;
|
||||
const windowFocusedRef = useRef(true);
|
||||
|
||||
// Pre-load audio element
|
||||
useEffect(() => {
|
||||
if (typeof window === "undefined") return;
|
||||
const audio = new Audio(NOTIFICATION_SOUND_PATH);
|
||||
audio.volume = 0.5;
|
||||
audioRef.current = audio;
|
||||
}, []);
|
||||
|
||||
// Listen for WebSocket notifications
|
||||
useEffect(() => {
|
||||
function handleNotification(notification: WebSocketNotification) {
|
||||
if (notification.type !== "copilot_completion") return;
|
||||
if (notification.event !== "session_completed") return;
|
||||
|
||||
const sessionID = (notification as Record<string, unknown>).session_id;
|
||||
if (typeof sessionID !== "string") return;
|
||||
|
||||
const state = useCopilotUIStore.getState();
|
||||
|
||||
const isActiveSession = sessionID === activeSessionRef.current;
|
||||
const isUserAway =
|
||||
document.visibilityState === "hidden" || !windowFocusedRef.current;
|
||||
|
||||
// Skip if viewing the active session and it's in focus
|
||||
if (isActiveSession && !isUserAway) return;
|
||||
|
||||
// Skip if we already notified for this session (e.g. WS replay)
|
||||
if (state.completedSessionIDs.has(sessionID)) return;
|
||||
|
||||
// Always update UI state (checkmark + title) regardless of notification setting
|
||||
state.addCompletedSession(sessionID);
|
||||
const count = useCopilotUIStore.getState().completedSessionIDs.size;
|
||||
document.title = `(${count}) Otto is ready - ${ORIGINAL_TITLE}`;
|
||||
|
||||
// Sound and browser notifications are gated by the user setting
|
||||
if (!state.isNotificationsEnabled) return;
|
||||
|
||||
if (state.isSoundEnabled && audioRef.current) {
|
||||
audioRef.current.currentTime = 0;
|
||||
audioRef.current.play().catch(() => {});
|
||||
}
|
||||
|
||||
// Send browser notification when user is away
|
||||
if (
|
||||
typeof Notification !== "undefined" &&
|
||||
Notification.permission === "granted" &&
|
||||
isUserAway
|
||||
) {
|
||||
const n = new Notification("Otto is ready", {
|
||||
body: "A response is waiting for you.",
|
||||
icon: "/favicon.ico",
|
||||
});
|
||||
n.onclick = () => {
|
||||
window.focus();
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.set("sessionId", sessionID);
|
||||
window.history.pushState({}, "", url.toString());
|
||||
window.dispatchEvent(new PopStateEvent("popstate"));
|
||||
n.close();
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const detach = api.onWebSocketMessage("notification", handleNotification);
|
||||
return () => {
|
||||
detach();
|
||||
};
|
||||
}, [api]);
|
||||
|
||||
// Track window focus for browser notifications when app is in background
|
||||
useEffect(() => {
|
||||
function handleFocus() {
|
||||
windowFocusedRef.current = true;
|
||||
if (useCopilotUIStore.getState().completedSessionIDs.size === 0) {
|
||||
document.title = ORIGINAL_TITLE;
|
||||
}
|
||||
}
|
||||
function handleBlur() {
|
||||
windowFocusedRef.current = false;
|
||||
}
|
||||
function handleVisibilityChange() {
|
||||
if (
|
||||
document.visibilityState === "visible" &&
|
||||
useCopilotUIStore.getState().completedSessionIDs.size === 0
|
||||
) {
|
||||
document.title = ORIGINAL_TITLE;
|
||||
}
|
||||
}
|
||||
|
||||
window.addEventListener("focus", handleFocus);
|
||||
window.addEventListener("blur", handleBlur);
|
||||
document.addEventListener("visibilitychange", handleVisibilityChange);
|
||||
return () => {
|
||||
window.removeEventListener("focus", handleFocus);
|
||||
window.removeEventListener("blur", handleBlur);
|
||||
document.removeEventListener("visibilitychange", handleVisibilityChange);
|
||||
};
|
||||
}, []);
|
||||
}
|
||||
@@ -13,48 +13,11 @@ import type { FileUIPart } from "ai";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useCopilotUIStore } from "./store";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
import { useCopilotNotifications } from "./useCopilotNotifications";
|
||||
import { useCopilotStream } from "./useCopilotStream";
|
||||
|
||||
const TITLE_POLL_INTERVAL_MS = 2_000;
|
||||
const TITLE_POLL_MAX_ATTEMPTS = 5;
|
||||
|
||||
/**
|
||||
* Extract a prompt from the URL hash fragment.
|
||||
* Supports: /copilot#prompt=URL-encoded-text
|
||||
* Optionally auto-submits if ?autosubmit=true is in the query string.
|
||||
* Returns null if no prompt is present.
|
||||
*/
|
||||
function extractPromptFromUrl(): {
|
||||
prompt: string;
|
||||
autosubmit: boolean;
|
||||
} | null {
|
||||
if (typeof window === "undefined") return null;
|
||||
|
||||
const hash = window.location.hash;
|
||||
if (!hash) return null;
|
||||
|
||||
const hashParams = new URLSearchParams(hash.slice(1));
|
||||
const prompt = hashParams.get("prompt");
|
||||
|
||||
if (!prompt || !prompt.trim()) return null;
|
||||
|
||||
const searchParams = new URLSearchParams(window.location.search);
|
||||
const autosubmit = searchParams.get("autosubmit") === "true";
|
||||
|
||||
// Clean up hash + autosubmit param only (preserve other query params)
|
||||
const cleanURL = new URL(window.location.href);
|
||||
cleanURL.hash = "";
|
||||
cleanURL.searchParams.delete("autosubmit");
|
||||
window.history.replaceState(
|
||||
null,
|
||||
"",
|
||||
`${cleanURL.pathname}${cleanURL.search}`,
|
||||
);
|
||||
|
||||
return { prompt: prompt.trim(), autosubmit };
|
||||
}
|
||||
|
||||
interface UploadedFile {
|
||||
file_id: string;
|
||||
name: string;
|
||||
@@ -97,8 +60,6 @@ export function useCopilotPage() {
|
||||
refetchSession,
|
||||
});
|
||||
|
||||
useCopilotNotifications(sessionId);
|
||||
|
||||
// --- Delete session ---
|
||||
const { mutate: deleteSessionMutation, isPending: isDeleting } =
|
||||
useDeleteV2DeleteSession({
|
||||
@@ -163,28 +124,6 @@ export function useCopilotPage() {
|
||||
}
|
||||
}, [sessionId, pendingMessage, sendMessage]);
|
||||
|
||||
// --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
|
||||
const { setInitialPrompt } = useCopilotUIStore();
|
||||
const hasProcessedUrlPrompt = useRef(false);
|
||||
useEffect(() => {
|
||||
if (hasProcessedUrlPrompt.current) return;
|
||||
|
||||
const urlPrompt = extractPromptFromUrl();
|
||||
if (!urlPrompt) return;
|
||||
|
||||
hasProcessedUrlPrompt.current = true;
|
||||
|
||||
if (urlPrompt.autosubmit) {
|
||||
setPendingMessage(urlPrompt.prompt);
|
||||
void createSession().catch(() => {
|
||||
setPendingMessage(null);
|
||||
setInitialPrompt(urlPrompt.prompt);
|
||||
});
|
||||
} else {
|
||||
setInitialPrompt(urlPrompt.prompt);
|
||||
}
|
||||
}, [createSession, setInitialPrompt]);
|
||||
|
||||
async function uploadFiles(
|
||||
files: File[],
|
||||
sid: string,
|
||||
|
||||
@@ -44,7 +44,6 @@ export function NewAgentLibraryView() {
|
||||
handleSelectRun,
|
||||
handleCountsChange,
|
||||
handleClearSelectedRun,
|
||||
handleScheduleDeleted,
|
||||
onRunInitiated,
|
||||
onTriggerSetup,
|
||||
onScheduleCreated,
|
||||
@@ -197,7 +196,6 @@ export function NewAgentLibraryView() {
|
||||
selectedRunId={activeItem ?? undefined}
|
||||
onSelectRun={handleSelectRun}
|
||||
onClearSelectedRun={handleClearSelectedRun}
|
||||
onScheduleDeleted={handleScheduleDeleted}
|
||||
onTabChange={setActiveTab}
|
||||
onCountsChange={handleCountsChange}
|
||||
/>
|
||||
@@ -208,7 +206,7 @@ export function NewAgentLibraryView() {
|
||||
<SelectedScheduleView
|
||||
agent={agent}
|
||||
scheduleId={activeItem}
|
||||
onScheduleDeleted={handleScheduleDeleted}
|
||||
onClearSelectedRun={handleClearSelectedRun}
|
||||
banner={renderMarketplaceUpdateBanner()}
|
||||
/>
|
||||
) : activeTab === "templates" ? (
|
||||
|
||||
@@ -260,7 +260,7 @@ export function CronScheduler({
|
||||
)}
|
||||
|
||||
{frequency !== "hourly" &&
|
||||
!(frequency === "custom" && customInterval.unit !== "days") && (
|
||||
!(frequency === "custom" && customInterval.unit === "hours") && (
|
||||
<TimeAt
|
||||
value={selectedTime}
|
||||
onChange={setSelectedTime}
|
||||
|
||||
@@ -19,14 +19,14 @@ import { useSelectedScheduleView } from "./useSelectedScheduleView";
|
||||
interface Props {
|
||||
agent: LibraryAgent;
|
||||
scheduleId: string;
|
||||
onScheduleDeleted?: (deletedScheduleId: string) => void;
|
||||
onClearSelectedRun?: () => void;
|
||||
banner?: React.ReactNode;
|
||||
}
|
||||
|
||||
export function SelectedScheduleView({
|
||||
agent,
|
||||
scheduleId,
|
||||
onScheduleDeleted,
|
||||
onClearSelectedRun,
|
||||
banner,
|
||||
}: Props) {
|
||||
const { schedule, isLoading, error } = useSelectedScheduleView(
|
||||
@@ -89,7 +89,7 @@ export function SelectedScheduleView({
|
||||
<SelectedScheduleActions
|
||||
agent={agent}
|
||||
scheduleId={schedule.id}
|
||||
onDeleted={() => onScheduleDeleted?.(schedule.id)}
|
||||
onDeleted={onClearSelectedRun}
|
||||
/>
|
||||
</div>
|
||||
) : null}
|
||||
@@ -168,7 +168,7 @@ export function SelectedScheduleView({
|
||||
<SelectedScheduleActions
|
||||
agent={agent}
|
||||
scheduleId={schedule.id}
|
||||
onDeleted={() => onScheduleDeleted?.(schedule.id)}
|
||||
onDeleted={onClearSelectedRun}
|
||||
/>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
@@ -28,15 +28,13 @@ export function useSelectedScheduleActions({
|
||||
mutation: {
|
||||
onSuccess: () => {
|
||||
toast({ title: "Schedule deleted" });
|
||||
setShowDeleteDialog(false);
|
||||
|
||||
onDeleted?.();
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV1ListExecutionSchedulesForAGraphQueryOptions(
|
||||
agent.graph_id,
|
||||
).queryKey,
|
||||
});
|
||||
setShowDeleteDialog(false);
|
||||
onDeleted?.();
|
||||
},
|
||||
onError: (error: unknown) =>
|
||||
toast({
|
||||
|
||||
@@ -28,7 +28,6 @@ interface Props {
|
||||
tab?: "runs" | "scheduled" | "templates" | "triggers",
|
||||
) => void;
|
||||
onClearSelectedRun?: () => void;
|
||||
onScheduleDeleted?: (deletedScheduleId: string) => void;
|
||||
onTabChange?: (tab: "runs" | "scheduled" | "templates" | "triggers") => void;
|
||||
onCountsChange?: (info: {
|
||||
runsCount: number;
|
||||
@@ -44,7 +43,6 @@ export function SidebarRunsList({
|
||||
selectedRunId,
|
||||
onSelectRun,
|
||||
onClearSelectedRun,
|
||||
onScheduleDeleted,
|
||||
onTabChange,
|
||||
onCountsChange,
|
||||
}: Props) {
|
||||
@@ -185,7 +183,6 @@ export function SidebarRunsList({
|
||||
agent={agent}
|
||||
selected={selectedRunId === s.id}
|
||||
onClick={() => onSelectRun(s.id, "scheduled")}
|
||||
onDeleted={() => onScheduleDeleted?.(s.id)}
|
||||
/>
|
||||
</div>
|
||||
))
|
||||
|
||||
@@ -39,15 +39,15 @@ export function ScheduleActionsDropdown({ agent, schedule, onDeleted }: Props) {
|
||||
await deleteSchedule({ scheduleId: schedule.id });
|
||||
|
||||
toast({ title: "Schedule deleted" });
|
||||
setShowDeleteDialog(false);
|
||||
|
||||
onDeleted?.();
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV1ListExecutionSchedulesForAGraphQueryOptions(
|
||||
agent.graph_id,
|
||||
).queryKey,
|
||||
});
|
||||
|
||||
setShowDeleteDialog(false);
|
||||
onDeleted?.();
|
||||
} catch (error: unknown) {
|
||||
toast({
|
||||
title: "Failed to delete schedule",
|
||||
|
||||
@@ -47,7 +47,7 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
|
||||
{
|
||||
query: {
|
||||
// Only fetch if user is the creator
|
||||
enabled: !!(user?.id && agent?.can_access_graph),
|
||||
enabled: !!(user?.id && agent?.owner_user_id === user.id),
|
||||
},
|
||||
},
|
||||
);
|
||||
@@ -90,7 +90,7 @@ export function useMarketplaceUpdate({ agent }: UseMarketplaceUpdateProps) {
|
||||
};
|
||||
}
|
||||
|
||||
const isUserCreator = !!(agent?.can_access_graph && user?.id);
|
||||
const isUserCreator = agent?.owner_user_id === user?.id;
|
||||
|
||||
const submissionsResponse = okData(submissionsData) as any;
|
||||
const agentSubmissions =
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useGetV2GetLibraryAgent } from "@/app/api/__generated__/endpoints/library/library";
|
||||
import { useGetV2GetASpecificPreset } from "@/app/api/__generated__/endpoints/presets/presets";
|
||||
import { useGetV1ListExecutionSchedulesForAGraph } from "@/app/api/__generated__/endpoints/schedules/schedules";
|
||||
import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
|
||||
import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
||||
import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset";
|
||||
@@ -123,37 +122,6 @@ export function useNewAgentLibraryView() {
|
||||
});
|
||||
}
|
||||
|
||||
const { data: schedules } = useGetV1ListExecutionSchedulesForAGraph(
|
||||
agent?.graph_id || "",
|
||||
{
|
||||
query: {
|
||||
enabled: !!agent?.graph_id,
|
||||
select: okData,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
function handleScheduleDeleted(deletedScheduleId: string) {
|
||||
if (activeItem !== deletedScheduleId) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!schedules) {
|
||||
handleClearSelectedRun();
|
||||
return;
|
||||
}
|
||||
|
||||
const remainingSchedules = schedules.filter(
|
||||
(s) => s.id !== deletedScheduleId,
|
||||
);
|
||||
|
||||
if (remainingSchedules.length > 0) {
|
||||
handleSelectRun(remainingSchedules[0].id, "scheduled");
|
||||
} else {
|
||||
handleClearSelectedRun();
|
||||
}
|
||||
}
|
||||
|
||||
function handleSetActiveTab(
|
||||
tab: "runs" | "scheduled" | "templates" | "triggers",
|
||||
) {
|
||||
@@ -237,7 +205,6 @@ export function useNewAgentLibraryView() {
|
||||
activeTab,
|
||||
setActiveTab: handleSetActiveTab,
|
||||
handleClearSelectedRun,
|
||||
handleScheduleDeleted,
|
||||
handleCountsChange,
|
||||
handleSelectRun,
|
||||
handleSelectSettings,
|
||||
|
||||
@@ -42,14 +42,6 @@ function ResetPasswordContent() {
|
||||
|
||||
if (isExpiredOrUsed) {
|
||||
setShowExpiredMessage(true);
|
||||
// Also show a toast with the Supabase error detail for debugging
|
||||
if (errorDescription) {
|
||||
toast({
|
||||
title: "Link Expired",
|
||||
description: errorDescription,
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Show toast for other errors
|
||||
const errorMessage =
|
||||
|
||||
60
autogpt_platform/frontend/src/app/api/__generated__/models/responseType.ts
generated
Normal file
60
autogpt_platform/frontend/src/app/api/__generated__/models/responseType.ts
generated
Normal file
@@ -0,0 +1,60 @@
|
||||
/**
|
||||
* Generated by orval v7.13.0 🍺
|
||||
* Do not edit manually.
|
||||
* AutoGPT Agent Server
|
||||
* This server is used to execute agents that are created by the AutoGPT system.
|
||||
* OpenAPI spec version: 0.1
|
||||
*/
|
||||
|
||||
/**
|
||||
* Types of tool responses.
|
||||
*/
|
||||
export type ResponseType = typeof ResponseType[keyof typeof ResponseType];
|
||||
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-redeclare
|
||||
export const ResponseType = {
|
||||
error: 'error',
|
||||
no_results: 'no_results',
|
||||
need_login: 'need_login',
|
||||
agents_found: 'agents_found',
|
||||
agent_details: 'agent_details',
|
||||
setup_requirements: 'setup_requirements',
|
||||
input_validation_error: 'input_validation_error',
|
||||
execution_started: 'execution_started',
|
||||
agent_output: 'agent_output',
|
||||
understanding_updated: 'understanding_updated',
|
||||
suggested_goal: 'suggested_goal',
|
||||
agent_builder_guide: 'agent_builder_guide',
|
||||
agent_builder_preview: 'agent_builder_preview',
|
||||
agent_builder_saved: 'agent_builder_saved',
|
||||
agent_builder_clarification_needed: 'agent_builder_clarification_needed',
|
||||
agent_builder_validation_result: 'agent_builder_validation_result',
|
||||
agent_builder_fix_result: 'agent_builder_fix_result',
|
||||
block_list: 'block_list',
|
||||
block_details: 'block_details',
|
||||
block_output: 'block_output',
|
||||
mcp_guide: 'mcp_guide',
|
||||
mcp_tools_discovered: 'mcp_tools_discovered',
|
||||
mcp_tool_output: 'mcp_tool_output',
|
||||
doc_search_results: 'doc_search_results',
|
||||
doc_page: 'doc_page',
|
||||
workspace_file_list: 'workspace_file_list',
|
||||
workspace_file_content: 'workspace_file_content',
|
||||
workspace_file_metadata: 'workspace_file_metadata',
|
||||
workspace_file_written: 'workspace_file_written',
|
||||
workspace_file_deleted: 'workspace_file_deleted',
|
||||
folder_created: 'folder_created',
|
||||
folder_list: 'folder_list',
|
||||
folder_updated: 'folder_updated',
|
||||
folder_moved: 'folder_moved',
|
||||
folder_deleted: 'folder_deleted',
|
||||
agents_moved_to_folder: 'agents_moved_to_folder',
|
||||
browser_navigate: 'browser_navigate',
|
||||
browser_act: 'browser_act',
|
||||
browser_screenshot: 'browser_screenshot',
|
||||
bash_exec: 'bash_exec',
|
||||
web_fetch: 'web_fetch',
|
||||
feature_request_search: 'feature_request_search',
|
||||
feature_request_created: 'feature_request_created',
|
||||
} as const;
|
||||
@@ -9,25 +9,6 @@ export async function GET(request: NextRequest) {
|
||||
process.env.NEXT_PUBLIC_FRONTEND_BASE_URL || "http://localhost:3000";
|
||||
|
||||
if (!code) {
|
||||
// Supabase may redirect here with error params instead of a code
|
||||
// (e.g. when the OTP token is expired or already used)
|
||||
const error = searchParams.get("error");
|
||||
const errorCode = searchParams.get("error_code");
|
||||
const errorDescription = searchParams.get("error_description");
|
||||
|
||||
if (error || errorCode || errorDescription) {
|
||||
// Forward raw Supabase error params to the reset-password page,
|
||||
// which already handles classification (expired vs other errors)
|
||||
const params = new URLSearchParams();
|
||||
if (error) params.set("error", error);
|
||||
if (errorCode) params.set("error_code", errorCode);
|
||||
if (errorDescription) params.set("error_description", errorDescription);
|
||||
|
||||
return NextResponse.redirect(
|
||||
`${origin}/reset-password?${params.toString()}`,
|
||||
);
|
||||
}
|
||||
|
||||
return NextResponse.redirect(
|
||||
`${origin}/reset-password?error=${encodeURIComponent("Missing verification code")}`,
|
||||
);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -57,25 +57,27 @@ async function handleWorkspaceDownload(
|
||||
);
|
||||
}
|
||||
|
||||
// Fully buffer the response before forwarding. Passing response.body as a
|
||||
// ReadableStream causes silent truncation in Next.js / Vercel — the last
|
||||
// ~10 KB of larger files are dropped, corrupting PNGs and truncating CSVs.
|
||||
const buffer = await response.arrayBuffer();
|
||||
|
||||
// Get the content type from the backend response
|
||||
const contentType =
|
||||
response.headers.get("Content-Type") || "application/octet-stream";
|
||||
const contentDisposition = response.headers.get("Content-Disposition");
|
||||
|
||||
// Stream the response body
|
||||
const responseHeaders: Record<string, string> = {
|
||||
"Content-Type": contentType,
|
||||
"Content-Length": String(buffer.byteLength),
|
||||
};
|
||||
|
||||
if (contentDisposition) {
|
||||
responseHeaders["Content-Disposition"] = contentDisposition;
|
||||
}
|
||||
|
||||
return new NextResponse(buffer, {
|
||||
const contentLength = response.headers.get("Content-Length");
|
||||
if (contentLength) {
|
||||
responseHeaders["Content-Length"] = contentLength;
|
||||
}
|
||||
|
||||
// Stream the response body directly instead of buffering in memory
|
||||
return new NextResponse(response.body, {
|
||||
status: 200,
|
||||
headers: responseHeaders,
|
||||
});
|
||||
|
||||
@@ -13,77 +13,26 @@ const PREFERRED_VOICES = [
|
||||
"Samantha",
|
||||
"Karen",
|
||||
"Daniel",
|
||||
"Moira",
|
||||
"Tessa",
|
||||
// Chrome / Android
|
||||
"Google UK English Female",
|
||||
"Google UK English Male",
|
||||
"Google US English",
|
||||
// Edge / Windows OneCore
|
||||
// Edge / Windows
|
||||
"Microsoft Zira",
|
||||
"Microsoft David",
|
||||
"Microsoft Jenny",
|
||||
"Microsoft Aria",
|
||||
"Microsoft Guy",
|
||||
];
|
||||
|
||||
/**
|
||||
* Name fragments that indicate low-quality / robotic synthesis engines.
|
||||
* Matching is case-insensitive on `voice.name`.
|
||||
*/
|
||||
const ROBOTIC_VOICE_INDICATORS = [
|
||||
"espeak",
|
||||
"festival",
|
||||
"mbrola",
|
||||
"flite",
|
||||
"pico",
|
||||
];
|
||||
|
||||
/** Returns true when a voice is likely low-quality / robotic. */
|
||||
function isLikelyRobotic(voice: SpeechSynthesisVoice): boolean {
|
||||
const lower = voice.name.toLowerCase();
|
||||
return ROBOTIC_VOICE_INDICATORS.some((ind) => lower.includes(ind));
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects the best available voice using a 4-tier fallback:
|
||||
* 1. Known high-quality voices by name (PREFERRED_VOICES)
|
||||
* 2. Non-robotic remote/cloud voices
|
||||
* 3. Non-robotic local English voices
|
||||
* 4. Any remaining voice
|
||||
*/
|
||||
function pickBestVoice(
|
||||
voices: SpeechSynthesisVoice[],
|
||||
): SpeechSynthesisVoice | undefined {
|
||||
// 1. Try preferred voices first (known high-quality)
|
||||
function pickBestVoice(): SpeechSynthesisVoice | undefined {
|
||||
const voices = window.speechSynthesis.getVoices();
|
||||
for (const name of PREFERRED_VOICES) {
|
||||
const matches = voices.filter((v) => v.name.includes(name));
|
||||
if (matches.length === 0) continue;
|
||||
|
||||
return (
|
||||
matches.find((v) => !isLikelyRobotic(v) && !v.localService) ||
|
||||
matches.find((v) => !isLikelyRobotic(v)) ||
|
||||
matches.find((v) => !v.localService) ||
|
||||
matches[0]
|
||||
);
|
||||
const match = voices.find((v) => v.name.includes(name));
|
||||
if (match) return match;
|
||||
}
|
||||
|
||||
// 2. Filter out known robotic / low-quality voices
|
||||
const nonRobotic = voices.filter((v) => !isLikelyRobotic(v));
|
||||
const candidates = nonRobotic.length > 0 ? nonRobotic : voices;
|
||||
|
||||
// 3. Prefer remote / cloud-backed voices (usually higher quality)
|
||||
const remote = candidates.filter((v) => !v.localService);
|
||||
if (remote.length > 0) {
|
||||
return remote.find((v) => v.lang.startsWith("en")) || remote[0];
|
||||
}
|
||||
|
||||
// 4. Best remaining local voice: English default → English → default → first
|
||||
// Fallback: prefer any voice flagged as default, or the first English voice
|
||||
return (
|
||||
candidates.find((v) => v.default && v.lang.startsWith("en")) ||
|
||||
candidates.find((v) => v.lang.startsWith("en")) ||
|
||||
candidates.find((v) => v.default) ||
|
||||
candidates[0]
|
||||
voices.find((v) => v.default) ||
|
||||
voices.find((v) => v.lang.startsWith("en")) ||
|
||||
voices[0]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -91,31 +40,9 @@ export function useTextToSpeech(text: string) {
|
||||
const [status, setStatus] = useState<TTSStatus>("idle");
|
||||
const [isSupported, setIsSupported] = useState(false);
|
||||
const utteranceRef = useRef<SpeechSynthesisUtterance | null>(null);
|
||||
const cachedVoicesRef = useRef<SpeechSynthesisVoice[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const supported = "speechSynthesis" in window;
|
||||
setIsSupported(supported);
|
||||
|
||||
if (supported) {
|
||||
// Populate cache immediately (may already be available)
|
||||
cachedVoicesRef.current = window.speechSynthesis.getVoices();
|
||||
|
||||
// Listen for voiceschanged to populate cache when voices load async
|
||||
function onVoicesChanged() {
|
||||
cachedVoicesRef.current = window.speechSynthesis.getVoices();
|
||||
}
|
||||
window.speechSynthesis.addEventListener("voiceschanged", onVoicesChanged);
|
||||
|
||||
return () => {
|
||||
window.speechSynthesis.removeEventListener(
|
||||
"voiceschanged",
|
||||
onVoicesChanged,
|
||||
);
|
||||
window.speechSynthesis.cancel();
|
||||
};
|
||||
}
|
||||
|
||||
setIsSupported("speechSynthesis" in window);
|
||||
return () => {
|
||||
window.speechSynthesis?.cancel();
|
||||
};
|
||||
@@ -142,7 +69,7 @@ export function useTextToSpeech(text: string) {
|
||||
|
||||
const utterance = new SpeechSynthesisUtterance(text);
|
||||
|
||||
const voice = pickBestVoice(cachedVoicesRef.current);
|
||||
const voice = pickBestVoice();
|
||||
if (voice) utterance.voice = voice;
|
||||
|
||||
utteranceRef.current = utterance;
|
||||
|
||||
@@ -9,7 +9,6 @@ import { cn } from "@/lib/utils";
|
||||
import { toDisplayName } from "@/providers/agent-credentials/helper";
|
||||
import { APIKeyCredentialsModal } from "./components/APIKeyCredentialsModal/APIKeyCredentialsModal";
|
||||
import { CredentialsFlatView } from "./components/CredentialsFlatView/CredentialsFlatView";
|
||||
import { CredentialTypeSelector } from "./components/CredentialTypeSelector/CredentialTypeSelector";
|
||||
import { HostScopedCredentialsModal } from "./components/HotScopedCredentialsModal/HotScopedCredentialsModal";
|
||||
import { OAuthFlowWaitingModal } from "./components/OAuthWaitingModal/OAuthWaitingModal";
|
||||
import { PasswordCredentialsModal } from "./components/PasswordCredentialsModal/PasswordCredentialsModal";
|
||||
@@ -71,25 +70,20 @@ export function CredentialsInput({
|
||||
supportsOAuth2,
|
||||
supportsUserPassword,
|
||||
supportsHostScoped,
|
||||
hasMultipleCredentialTypes,
|
||||
supportedTypes,
|
||||
userCredentials,
|
||||
systemCredentials,
|
||||
oAuthError,
|
||||
isAPICredentialsModalOpen,
|
||||
isUserPasswordCredentialsModalOpen,
|
||||
isHostScopedCredentialsModalOpen,
|
||||
isCredentialTypeSelectorOpen,
|
||||
isOAuth2FlowInProgress,
|
||||
oAuthPopupController,
|
||||
actionButtonText,
|
||||
setAPICredentialsModalOpen,
|
||||
setUserPasswordCredentialsModalOpen,
|
||||
setHostScopedCredentialsModalOpen,
|
||||
setCredentialTypeSelectorOpen,
|
||||
handleActionButtonClick,
|
||||
handleCredentialSelect,
|
||||
handleOAuthLogin,
|
||||
} = hookData;
|
||||
|
||||
const displayName = toDisplayName(provider);
|
||||
@@ -122,28 +116,7 @@ export function CredentialsInput({
|
||||
|
||||
{!readOnly && (
|
||||
<>
|
||||
{hasMultipleCredentialTypes && (
|
||||
<CredentialTypeSelector
|
||||
schema={schema}
|
||||
open={isCredentialTypeSelectorOpen}
|
||||
onClose={() => setCredentialTypeSelectorOpen(false)}
|
||||
provider={provider}
|
||||
providerName={providerName}
|
||||
supportedTypes={supportedTypes}
|
||||
onCredentialsCreate={(creds) => {
|
||||
onSelectCredential(creds);
|
||||
}}
|
||||
onOAuthLogin={handleOAuthLogin}
|
||||
onOpenPasswordModal={() =>
|
||||
setUserPasswordCredentialsModalOpen(true)
|
||||
}
|
||||
onOpenHostScopedModal={() =>
|
||||
setHostScopedCredentialsModalOpen(true)
|
||||
}
|
||||
siblingInputs={siblingInputs}
|
||||
/>
|
||||
)}
|
||||
{supportsApiKey && !hasMultipleCredentialTypes && (
|
||||
{supportsApiKey && (
|
||||
<APIKeyCredentialsModal
|
||||
schema={schema}
|
||||
open={isAPICredentialsModalOpen}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { IconKey } from "@/components/__legacy__/ui/icons";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
DropdownMenu,
|
||||
@@ -8,12 +9,9 @@ import {
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CaretDownIcon, DotsThreeVertical } from "@phosphor-icons/react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { CredentialsType } from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
fallbackIcon,
|
||||
getCredentialDisplayName,
|
||||
getCredentialTypeIcon,
|
||||
getCredentialTypeLabel,
|
||||
MASKED_KEY_LENGTH,
|
||||
providerIcons,
|
||||
} from "../../helpers";
|
||||
@@ -49,16 +47,6 @@ export function CredentialRow({
|
||||
variant = "default",
|
||||
}: CredentialRowProps) {
|
||||
const ProviderIcon = providerIcons[provider] || fallbackIcon;
|
||||
const isRealCredentialType = [
|
||||
"api_key",
|
||||
"oauth2",
|
||||
"user_password",
|
||||
"host_scoped",
|
||||
].includes(credential.type);
|
||||
const credType = credential.type as CredentialsType;
|
||||
const TypeIcon = isRealCredentialType
|
||||
? getCredentialTypeIcon(credType, provider)
|
||||
: fallbackIcon;
|
||||
const isNodeVariant = variant === "node";
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const [showMaskedKey, setShowMaskedKey] = useState(true);
|
||||
@@ -104,29 +92,22 @@ export function CredentialRow({
|
||||
<div className="flex h-6 w-6 shrink-0 items-center justify-center rounded-full bg-gray-900">
|
||||
<ProviderIcon className="h-3 w-3 text-white" />
|
||||
</div>
|
||||
<TypeIcon className="h-5 w-5 shrink-0 text-zinc-800" />
|
||||
<IconKey className="h-5 w-5 shrink-0 text-zinc-800" />
|
||||
<div
|
||||
className={cn(
|
||||
"relative flex min-w-0 flex-1 flex-nowrap items-center gap-4",
|
||||
isNodeVariant && "overflow-hidden",
|
||||
)}
|
||||
>
|
||||
<div className="flex min-w-0 flex-1 items-center gap-2">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"min-w-0 shrink tracking-tight",
|
||||
isNodeVariant ? "truncate" : "line-clamp-1 text-ellipsis",
|
||||
)}
|
||||
>
|
||||
{getCredentialDisplayName(credential, displayName)}
|
||||
</Text>
|
||||
{isRealCredentialType && (
|
||||
<span className="shrink-0 rounded bg-zinc-100 px-1.5 py-0.5 text-[0.625rem] font-medium leading-tight text-zinc-500">
|
||||
{getCredentialTypeLabel(credType)}
|
||||
</span>
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"min-w-0 flex-1 tracking-tight",
|
||||
isNodeVariant ? "truncate" : "line-clamp-1 text-ellipsis",
|
||||
)}
|
||||
</div>
|
||||
>
|
||||
{getCredentialDisplayName(credential, displayName)}
|
||||
</Text>
|
||||
{!(asSelectTrigger && isNodeVariant) && showMaskedKey && (
|
||||
<Text
|
||||
variant="large"
|
||||
|
||||
@@ -1,298 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Form,
|
||||
FormDescription,
|
||||
FormField,
|
||||
} from "@/components/__legacy__/ui/form";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import {
|
||||
TabsLine,
|
||||
TabsLineContent,
|
||||
TabsLineList,
|
||||
TabsLineTrigger,
|
||||
} from "@/components/molecules/TabsLine/TabsLine";
|
||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
CredentialsType,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { useAPIKeyCredentialsModal } from "../APIKeyCredentialsModal/useAPIKeyCredentialsModal";
|
||||
import { getCredentialTypeIcon, getCredentialTypeLabel } from "../../helpers";
|
||||
|
||||
type Props = {
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
provider: string;
|
||||
providerName: string;
|
||||
supportedTypes: CredentialsType[];
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
onOAuthLogin: () => void;
|
||||
onOpenPasswordModal: () => void;
|
||||
onOpenHostScopedModal: () => void;
|
||||
siblingInputs?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export function CredentialTypeSelector({
|
||||
schema,
|
||||
open,
|
||||
onClose,
|
||||
provider,
|
||||
providerName,
|
||||
supportedTypes,
|
||||
onCredentialsCreate,
|
||||
onOAuthLogin,
|
||||
onOpenPasswordModal,
|
||||
onOpenHostScopedModal,
|
||||
siblingInputs,
|
||||
}: Props) {
|
||||
const defaultTab = supportedTypes[0];
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title={`Add credential for ${providerName}`}
|
||||
controlled={{
|
||||
isOpen: open,
|
||||
set: (isOpen) => {
|
||||
if (!isOpen) onClose();
|
||||
},
|
||||
}}
|
||||
onClose={onClose}
|
||||
styling={{ maxWidth: "28rem" }}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<TabsLine defaultValue={defaultTab}>
|
||||
<TabsLineList>
|
||||
{supportedTypes.map((type) => {
|
||||
const Icon = getCredentialTypeIcon(type, provider);
|
||||
return (
|
||||
<TabsLineTrigger
|
||||
key={type}
|
||||
value={type}
|
||||
className="inline-flex items-center gap-1.5"
|
||||
>
|
||||
<Icon size={16} />
|
||||
{getCredentialTypeLabel(type)}
|
||||
</TabsLineTrigger>
|
||||
);
|
||||
})}
|
||||
</TabsLineList>
|
||||
|
||||
{supportedTypes.includes("oauth2") && (
|
||||
<TabsLineContent value="oauth2">
|
||||
<OAuthTabContent
|
||||
providerName={providerName}
|
||||
onOAuthLogin={() => {
|
||||
onClose();
|
||||
onOAuthLogin();
|
||||
}}
|
||||
/>
|
||||
</TabsLineContent>
|
||||
)}
|
||||
|
||||
{supportedTypes.includes("api_key") && (
|
||||
<TabsLineContent value="api_key">
|
||||
<APIKeyTabContent
|
||||
schema={schema}
|
||||
siblingInputs={siblingInputs}
|
||||
onCredentialsCreate={(creds) => {
|
||||
onCredentialsCreate(creds);
|
||||
onClose();
|
||||
}}
|
||||
/>
|
||||
</TabsLineContent>
|
||||
)}
|
||||
|
||||
{supportedTypes.includes("user_password") && (
|
||||
<TabsLineContent value="user_password">
|
||||
<SimpleActionTab
|
||||
description="Add a username and password credential."
|
||||
buttonLabel="Enter credentials"
|
||||
onClick={() => {
|
||||
onClose();
|
||||
onOpenPasswordModal();
|
||||
}}
|
||||
/>
|
||||
</TabsLineContent>
|
||||
)}
|
||||
|
||||
{supportedTypes.includes("host_scoped") && (
|
||||
<TabsLineContent value="host_scoped">
|
||||
<SimpleActionTab
|
||||
description="Add host-scoped headers for authentication."
|
||||
buttonLabel="Add headers"
|
||||
onClick={() => {
|
||||
onClose();
|
||||
onOpenHostScopedModal();
|
||||
}}
|
||||
/>
|
||||
</TabsLineContent>
|
||||
)}
|
||||
</TabsLine>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
type OAuthTabContentProps = {
|
||||
providerName: string;
|
||||
onOAuthLogin: () => void;
|
||||
};
|
||||
|
||||
function OAuthTabContent({ providerName, onOAuthLogin }: OAuthTabContentProps) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<p className="text-sm text-zinc-600">
|
||||
Sign in with your {providerName} account using OAuth.
|
||||
</p>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onOAuthLogin}
|
||||
type="button"
|
||||
>
|
||||
Sign in with {providerName}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type APIKeyTabContentProps = {
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
siblingInputs?: Record<string, unknown>;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
};
|
||||
|
||||
function APIKeyTabContent({
|
||||
schema,
|
||||
siblingInputs,
|
||||
onCredentialsCreate,
|
||||
}: APIKeyTabContentProps) {
|
||||
const {
|
||||
form,
|
||||
isLoading,
|
||||
isSubmitting,
|
||||
supportsApiKey,
|
||||
schemaDescription,
|
||||
onSubmit,
|
||||
} = useAPIKeyCredentialsModal({ schema, siblingInputs, onCredentialsCreate });
|
||||
|
||||
if (!supportsApiKey && !isLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<Skeleton className="h-4 w-3/4" />
|
||||
<div className="space-y-2">
|
||||
<Skeleton className="h-9 w-full" />
|
||||
<Skeleton className="h-9 w-full" />
|
||||
<Skeleton className="h-9 w-full" />
|
||||
<Skeleton className="h-9 w-68" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{schemaDescription && (
|
||||
<p className="text-sm text-zinc-600">{schemaDescription}</p>
|
||||
)}
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-2">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
id="title"
|
||||
label="Name"
|
||||
type="text"
|
||||
placeholder="Enter a name for this API Key..."
|
||||
{...field}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="apiKey"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
id="apiKey"
|
||||
label="API Key"
|
||||
type="password"
|
||||
placeholder="Enter API Key..."
|
||||
hint={
|
||||
schema.credentials_scopes ? (
|
||||
<FormDescription>
|
||||
Required scope(s) for this block:{" "}
|
||||
{schema.credentials_scopes?.map((s, i, a) => (
|
||||
<span key={i}>
|
||||
<code className="text-xs font-bold">{s}</code>
|
||||
{i < a.length - 1 && ", "}
|
||||
</span>
|
||||
))}
|
||||
</FormDescription>
|
||||
) : null
|
||||
}
|
||||
{...field}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="expiresAt"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
id="expiresAt"
|
||||
label="Expiration Date"
|
||||
type="datetime-local"
|
||||
placeholder="Select expiration date..."
|
||||
value={field.value}
|
||||
onChange={(e) => field.onChange(e.target.value)}
|
||||
onBlur={field.onBlur}
|
||||
name={field.name}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<Button
|
||||
type="submit"
|
||||
className="min-w-68"
|
||||
loading={isSubmitting}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Add API Key
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type SimpleActionTabProps = {
|
||||
description: string;
|
||||
buttonLabel: string;
|
||||
onClick: () => void;
|
||||
};
|
||||
|
||||
function SimpleActionTab({
|
||||
description,
|
||||
buttonLabel,
|
||||
onClick,
|
||||
}: SimpleActionTabProps) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<p className="text-sm text-zinc-600">{description}</p>
|
||||
<Button variant="primary" size="small" onClick={onClick} type="button">
|
||||
{buttonLabel}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,9 +1,5 @@
|
||||
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
||||
import { CredentialsType } from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
getCredentialDisplayName,
|
||||
getCredentialTypeLabel,
|
||||
} from "../../helpers";
|
||||
import { getCredentialDisplayName } from "../../helpers";
|
||||
import { CredentialRow } from "../CredentialRow/CredentialRow";
|
||||
|
||||
interface Props {
|
||||
@@ -90,8 +86,7 @@ export function CredentialsSelect({
|
||||
)}
|
||||
{credentials.map((credential) => (
|
||||
<option key={credential.id} value={credential.id}>
|
||||
{getCredentialDisplayName(credential, displayName)} (
|
||||
{getCredentialTypeLabel(credential.type as CredentialsType)})
|
||||
{getCredentialDisplayName(credential, displayName)}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user