mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-06 04:45:10 -05:00
Compare commits
1 Commits
pwuts/secr
...
refactor/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f7a7067ec |
3
autogpt_platform/backend/.gitignore
vendored
3
autogpt_platform/backend/.gitignore
vendored
@@ -19,6 +19,3 @@ load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
migrations/*/rollback*.sql
|
||||
|
||||
# Workspace files
|
||||
workspaces/
|
||||
|
||||
@@ -33,7 +33,7 @@ from backend.data.understanding import (
|
||||
get_business_understanding,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from . import stream_registry
|
||||
@@ -222,18 +222,8 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
try:
|
||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
# In non-production environments, fetch the latest prompt version
|
||||
# instead of the production-labeled version for easier testing
|
||||
label = (
|
||||
None
|
||||
if settings.config.app_env == AppEnvironment.PRODUCTION
|
||||
else "latest"
|
||||
)
|
||||
prompt = await asyncio.to_thread(
|
||||
langfuse.get_prompt,
|
||||
config.langfuse_prompt_name,
|
||||
label=label,
|
||||
cache_ttl_seconds=0,
|
||||
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
|
||||
)
|
||||
return prompt.compile(users_information=context)
|
||||
except Exception as e:
|
||||
@@ -628,9 +618,6 @@ async def stream_chat_completion(
|
||||
total_tokens=chunk.totalTokens,
|
||||
)
|
||||
)
|
||||
elif isinstance(chunk, StreamHeartbeat):
|
||||
# Pass through heartbeat to keep SSE connection alive
|
||||
yield chunk
|
||||
else:
|
||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||
|
||||
|
||||
@@ -7,7 +7,15 @@ from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||
from backend.data.graph import (
|
||||
Graph,
|
||||
Link,
|
||||
Node,
|
||||
create_graph,
|
||||
get_graph,
|
||||
get_graph_all_versions,
|
||||
get_store_listed_graphs,
|
||||
)
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .service import (
|
||||
@@ -20,6 +28,8 @@ from .service import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||
|
||||
|
||||
class ExecutionSummary(TypedDict):
|
||||
"""Summary of a single execution for quality assessment."""
|
||||
@@ -659,6 +669,45 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
)
|
||||
|
||||
|
||||
def _reassign_node_ids(graph: Graph) -> None:
|
||||
"""Reassign all node and link IDs to new UUIDs.
|
||||
|
||||
This is needed when creating a new version to avoid unique constraint violations.
|
||||
"""
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
for link in graph.links:
|
||||
link.id = str(uuid.uuid4())
|
||||
if link.source_id in id_map:
|
||||
link.source_id = id_map[link.source_id]
|
||||
if link.sink_id in id_map:
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
|
||||
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
|
||||
"""Populate user_id in AgentExecutorBlock nodes.
|
||||
|
||||
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
|
||||
This function fills in the actual user_id so sub-agents run with correct permissions.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON dict (modified in place)
|
||||
user_id: User ID to set
|
||||
"""
|
||||
for node in agent_json.get("nodes", []):
|
||||
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
|
||||
input_default = node.get("input_default") or {}
|
||||
if not input_default.get("user_id"):
|
||||
input_default["user_id"] = user_id
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
|
||||
)
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
) -> tuple[Graph, Any]:
|
||||
@@ -672,10 +721,35 @@ async def save_agent_to_library(
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
# Populate user_id in AgentExecutorBlock nodes before conversion
|
||||
_populate_agent_executor_user_ids(agent_json, user_id)
|
||||
|
||||
graph = json_to_graph(agent_json)
|
||||
|
||||
if is_update:
|
||||
return await library_db.update_graph_in_library(graph, user_id)
|
||||
return await library_db.create_graph_in_library(graph, user_id)
|
||||
if graph.id:
|
||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||
if existing_versions:
|
||||
latest_version = max(v.version for v in existing_versions)
|
||||
graph.version = latest_version + 1
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||
else:
|
||||
graph.id = str(uuid.uuid4())
|
||||
graph.version = 1
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Creating new agent with ID {graph.id}")
|
||||
|
||||
created_graph = await create_graph(graph, user_id)
|
||||
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
)
|
||||
|
||||
return created_graph, library_agents[0]
|
||||
|
||||
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
|
||||
@@ -206,9 +206,9 @@ async def search_agents(
|
||||
]
|
||||
)
|
||||
no_results_msg = (
|
||||
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
|
||||
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
||||
if source == "marketplace"
|
||||
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
|
||||
else f"No agents matching '{query}' found in your library."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
@@ -224,10 +224,10 @@ async def search_agents(
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
|
||||
"Please ask the user if they would like to use any of these agents."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
|
||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
@@ -27,6 +29,23 @@ from .models import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomizeAgentInput(BaseModel):
|
||||
"""Input parameters for the customize_agent tool."""
|
||||
|
||||
agent_id: str = ""
|
||||
modifications: str = ""
|
||||
context: str = ""
|
||||
save: bool = True
|
||||
|
||||
@field_validator("agent_id", "modifications", "context", mode="before")
|
||||
@classmethod
|
||||
def strip_strings(cls, v: Any) -> str:
|
||||
"""Strip whitespace from string fields."""
|
||||
if isinstance(v, str):
|
||||
return v.strip()
|
||||
return v if v is not None else ""
|
||||
|
||||
|
||||
class CustomizeAgentTool(BaseTool):
|
||||
"""Tool for customizing marketplace/template agents using natural language."""
|
||||
|
||||
@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the customize_agent tool.
|
||||
|
||||
@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
|
||||
3. Call customize_template with the modification request
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
params = CustomizeAgentInput(**kwargs)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
if not params.agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||
error="missing_agent_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not modifications:
|
||||
if not params.modifications:
|
||||
return ErrorResponse(
|
||||
message="Please describe how you want to customize this agent.",
|
||||
error="missing_modifications",
|
||||
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
|
||||
)
|
||||
|
||||
# Parse agent_id in format "creator/slug"
|
||||
parts = [p.strip() for p in agent_id.split("/")]
|
||||
parts = params.agent_id.split("/")
|
||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Invalid agent ID format: '{agent_id}'. "
|
||||
f"Invalid agent ID format: '{params.agent_id}'. "
|
||||
"Expected format is 'creator/agent-name' "
|
||||
"(e.g., 'autogpt/newsletter-writer')."
|
||||
),
|
||||
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
|
||||
except AgentNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
f"Could not find marketplace agent '{params.agent_id}'. "
|
||||
"Please check the agent ID and try again."
|
||||
),
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
||||
logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the marketplace agent. Please try again.",
|
||||
error="fetch_error",
|
||||
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
if not agent_details.store_listing_version_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The agent '{agent_id}' does not have an available version. "
|
||||
f"The agent '{params.agent_id}' does not have an available version. "
|
||||
"Please try a different agent."
|
||||
),
|
||||
error="no_version_available",
|
||||
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||
template_agent = graph_to_json(graph)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
||||
logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the agent configuration. Please try again.",
|
||||
error="graph_fetch_error",
|
||||
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
|
||||
try:
|
||||
result = await customize_template(
|
||||
template_agent=template_agent,
|
||||
modification_request=modifications,
|
||||
context=context,
|
||||
modification_request=params.modifications,
|
||||
context=params.context,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
||||
logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent due to a service error. "
|
||||
@@ -219,55 +235,25 @@ class CustomizeAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle error response
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="customize the agent",
|
||||
llm_parse_message=(
|
||||
"The AI had trouble customizing the agent. "
|
||||
"Please try again or simplify your request."
|
||||
),
|
||||
validation_message=(
|
||||
"The customized agent failed validation. "
|
||||
"Please try rephrasing your request."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"customization_failed:{error_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
# Handle response using match/case for cleaner pattern matching
|
||||
return await self._handle_customization_result(
|
||||
result=result,
|
||||
params=params,
|
||||
agent_details=agent_details,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle clarifying questions
|
||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions") or []
|
||||
if not isinstance(questions, list):
|
||||
logger.error(
|
||||
f"Unexpected clarifying questions format: {type(questions)}"
|
||||
)
|
||||
questions = []
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to customize this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
if isinstance(q, dict)
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result should be the customized agent JSON
|
||||
async def _handle_customization_result(
|
||||
self,
|
||||
result: dict[str, Any],
|
||||
params: CustomizeAgentInput,
|
||||
agent_details: Any,
|
||||
user_id: str | None,
|
||||
session_id: str | None,
|
||||
) -> ToolResponseBase:
|
||||
"""Handle the result from customize_template using pattern matching."""
|
||||
# Ensure result is a dict
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||
return ErrorResponse(
|
||||
@@ -276,8 +262,77 @@ class CustomizeAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
customized_agent = result
|
||||
result_type = result.get("type")
|
||||
|
||||
match result_type:
|
||||
case "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="customize the agent",
|
||||
llm_parse_message=(
|
||||
"The AI had trouble customizing the agent. "
|
||||
"Please try again or simplify your request."
|
||||
),
|
||||
validation_message=(
|
||||
"The customized agent failed validation. "
|
||||
"Please try rephrasing your request."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"customization_failed:{error_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
case "clarifying_questions":
|
||||
questions_data = result.get("questions") or []
|
||||
if not isinstance(questions_data, list):
|
||||
logger.error(
|
||||
f"Unexpected clarifying questions format: {type(questions_data)}"
|
||||
)
|
||||
questions_data = []
|
||||
|
||||
questions = [
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", "") if isinstance(q, dict) else "",
|
||||
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
|
||||
example=q.get("example") if isinstance(q, dict) else None,
|
||||
)
|
||||
for q in questions_data
|
||||
if isinstance(q, dict)
|
||||
]
|
||||
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to customize this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=questions,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
case _:
|
||||
# Default case: result is the customized agent JSON
|
||||
return await self._save_or_preview_agent(
|
||||
customized_agent=result,
|
||||
params=params,
|
||||
agent_details=agent_details,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _save_or_preview_agent(
|
||||
self,
|
||||
customized_agent: dict[str, Any],
|
||||
params: CustomizeAgentInput,
|
||||
agent_details: Any,
|
||||
user_id: str | None,
|
||||
session_id: str | None,
|
||||
) -> ToolResponseBase:
|
||||
"""Save or preview the customized agent based on params.save."""
|
||||
agent_name = customized_agent.get(
|
||||
"name", f"Customized {agent_details.agent_name}"
|
||||
)
|
||||
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||
link_count = len(links) if isinstance(links, list) else 0
|
||||
|
||||
if not save:
|
||||
if not params.save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||
|
||||
@@ -8,12 +8,7 @@ from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import (
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -128,7 +123,7 @@ def build_missing_credentials_from_graph(
|
||||
|
||||
return {
|
||||
field_key: _serialize_missing_credential(field_key, field_info)
|
||||
for field_key, (field_info, _, _) in aggregated_fields.items()
|
||||
for field_key, (field_info, _node_fields) in aggregated_fields.items()
|
||||
if field_key not in matched_keys
|
||||
}
|
||||
|
||||
@@ -269,8 +264,7 @@ async def match_user_credentials_to_graph(
|
||||
# provider is in the set of acceptable providers.
|
||||
for credential_field_name, (
|
||||
credential_requirements,
|
||||
_,
|
||||
_,
|
||||
_node_fields,
|
||||
) in aggregated_creds.items():
|
||||
# Find first matching credential by provider, type, and scopes
|
||||
matching_cred = next(
|
||||
@@ -279,14 +273,7 @@ async def match_user_credentials_to_graph(
|
||||
for cred in available_creds
|
||||
if cred.provider in credential_requirements.provider
|
||||
and cred.type in credential_requirements.supported_types
|
||||
and (
|
||||
cred.type != "oauth2"
|
||||
or _credential_has_required_scopes(cred, credential_requirements)
|
||||
)
|
||||
and (
|
||||
cred.type != "host_scoped"
|
||||
or _credential_is_for_host(cred, credential_requirements)
|
||||
)
|
||||
and _credential_has_required_scopes(cred, credential_requirements)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -331,10 +318,19 @@ async def match_user_credentials_to_graph(
|
||||
|
||||
|
||||
def _credential_has_required_scopes(
|
||||
credential: OAuth2Credentials,
|
||||
credential: Credentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""Check if an OAuth2 credential has all the scopes required by the input."""
|
||||
"""
|
||||
Check if a credential has all the scopes required by the block.
|
||||
|
||||
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||
of the required scopes. For other credential types, returns True (no scope check).
|
||||
"""
|
||||
# Only OAuth2 credentials have scopes to check
|
||||
if credential.type != "oauth2":
|
||||
return True
|
||||
|
||||
# If no scopes are required, any credential matches
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
@@ -343,22 +339,6 @@ def _credential_has_required_scopes(
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
|
||||
|
||||
def _credential_is_for_host(
|
||||
credential: HostScopedCredentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""Check if a host-scoped credential matches the host required by the input."""
|
||||
# We need to know the host to match host-scoped credentials to.
|
||||
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
|
||||
# to discriminator_values. No discriminator_values -> no host to match against.
|
||||
if not requirements.discriminator_values:
|
||||
return True
|
||||
|
||||
# Check that credential host matches required host.
|
||||
# Host-scoped credential inputs are grouped by host, so any item from the set works.
|
||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
||||
|
||||
|
||||
async def check_user_has_required_credentials(
|
||||
user_id: str,
|
||||
required_credentials: list[CredentialsMetaInput],
|
||||
|
||||
@@ -19,10 +19,7 @@ from backend.data.graph import GraphSettings
|
||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
on_graph_deactivate,
|
||||
)
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
@@ -540,92 +537,6 @@ async def update_agent_version_in_library(
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
|
||||
|
||||
async def create_graph_in_library(
|
||||
graph: graph_db.Graph,
|
||||
user_id: str,
|
||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||
"""Create a new graph and add it to the user's library."""
|
||||
graph.version = 1
|
||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
|
||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||
|
||||
library_agents = await create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
)
|
||||
|
||||
if created_graph.is_active:
|
||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
||||
|
||||
return created_graph, library_agents[0]
|
||||
|
||||
|
||||
async def update_graph_in_library(
|
||||
graph: graph_db.Graph,
|
||||
user_id: str,
|
||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||
"""Create a new version of an existing graph and update the library entry."""
|
||||
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
|
||||
current_active_version = (
|
||||
next((v for v in existing_versions if v.is_active), None)
|
||||
if existing_versions
|
||||
else None
|
||||
)
|
||||
graph.version = (
|
||||
max(v.version for v in existing_versions) + 1 if existing_versions else 1
|
||||
)
|
||||
|
||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||
|
||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||
|
||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
||||
if not library_agent:
|
||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
||||
|
||||
library_agent = await update_library_agent_version_and_settings(
|
||||
user_id, created_graph
|
||||
)
|
||||
|
||||
if created_graph.is_active:
|
||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=created_graph.id,
|
||||
version=created_graph.version,
|
||||
user_id=user_id,
|
||||
)
|
||||
if current_active_version:
|
||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||
|
||||
return created_graph, library_agent
|
||||
|
||||
|
||||
async def update_library_agent_version_and_settings(
|
||||
user_id: str, agent_graph: graph_db.GraphModel
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Update library agent to point to new graph version and sync settings."""
|
||||
library = await update_agent_version_in_library(
|
||||
user_id, agent_graph.id, agent_graph.version
|
||||
)
|
||||
updated_settings = GraphSettings.from_graph(
|
||||
graph=agent_graph,
|
||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||
)
|
||||
if updated_settings != library.settings:
|
||||
library = await update_library_agent(
|
||||
library_agent_id=library.id,
|
||||
user_id=user_id,
|
||||
settings=updated_settings,
|
||||
)
|
||||
return library
|
||||
|
||||
|
||||
async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str,
|
||||
|
||||
@@ -101,6 +101,7 @@ from backend.util.timezone_utils import (
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
from .library import db as library_db
|
||||
from .library import model as library_model
|
||||
from .store.model import StoreAgentDetails
|
||||
|
||||
|
||||
@@ -822,16 +823,18 @@ async def update_graph(
|
||||
graph: graph_db.Graph,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> graph_db.GraphModel:
|
||||
# Sanity check
|
||||
if graph.id and graph.id != graph_id:
|
||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||
|
||||
# Determine new version
|
||||
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||
if not existing_versions:
|
||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||
latest_version_number = max(g.version for g in existing_versions)
|
||||
graph.version = latest_version_number + 1
|
||||
|
||||
graph.version = max(g.version for g in existing_versions) + 1
|
||||
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
||||
|
||||
graph = graph_db.make_graph_model(graph, user_id)
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||
graph.validate_graph(for_run=False)
|
||||
@@ -839,23 +842,27 @@ async def update_graph(
|
||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
if new_graph_version.is_active:
|
||||
await library_db.update_library_agent_version_and_settings(
|
||||
user_id, new_graph_version
|
||||
)
|
||||
# Keep the library agent up to date with the new active version
|
||||
await _update_library_agent_version_and_settings(user_id, new_graph_version)
|
||||
|
||||
# Handle activation of the new graph first to ensure continuity
|
||||
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
||||
# Ensure new version is the only active version
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||
)
|
||||
if current_active_version:
|
||||
# Handle deactivation of the previously active version
|
||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||
|
||||
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
|
||||
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
||||
graph_id,
|
||||
new_graph_version.version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
assert new_graph_version_with_subgraphs
|
||||
assert new_graph_version_with_subgraphs # make type checker happy
|
||||
return new_graph_version_with_subgraphs
|
||||
|
||||
|
||||
@@ -893,15 +900,33 @@ async def set_graph_active_version(
|
||||
)
|
||||
|
||||
# Keep the library agent up to date with the new active version
|
||||
await library_db.update_library_agent_version_and_settings(
|
||||
user_id, new_active_graph
|
||||
)
|
||||
await _update_library_agent_version_and_settings(user_id, new_active_graph)
|
||||
|
||||
if current_active_graph and current_active_graph.version != new_active_version:
|
||||
# Handle deactivation of the previously active version
|
||||
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
||||
|
||||
|
||||
async def _update_library_agent_version_and_settings(
|
||||
user_id: str, agent_graph: graph_db.GraphModel
|
||||
) -> library_model.LibraryAgent:
|
||||
library = await library_db.update_agent_version_in_library(
|
||||
user_id, agent_graph.id, agent_graph.version
|
||||
)
|
||||
updated_settings = GraphSettings.from_graph(
|
||||
graph=agent_graph,
|
||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||
)
|
||||
if updated_settings != library.settings:
|
||||
library = await library_db.update_library_agent(
|
||||
library_agent_id=library.id,
|
||||
user_id=user_id,
|
||||
settings=updated_settings,
|
||||
)
|
||||
return library
|
||||
|
||||
|
||||
@v1_router.patch(
|
||||
path="/graphs/{graph_id}/settings",
|
||||
summary="Update graph settings",
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
"""Text encoding block for converting special characters to escape sequences."""
|
||||
|
||||
import codecs
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TextEncoderBlock(Block):
|
||||
"""
|
||||
Encodes a string by converting special characters into escape sequences.
|
||||
|
||||
This block is the inverse of TextDecoderBlock. It takes text containing
|
||||
special characters (like newlines, tabs, etc.) and converts them into
|
||||
their escape sequence representations (e.g., newline becomes \\n).
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for TextEncoderBlock."""
|
||||
|
||||
text: str = SchemaField(
|
||||
description="A string containing special characters to be encoded",
|
||||
placeholder="Your text with newlines and quotes to encode",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for TextEncoderBlock."""
|
||||
|
||||
encoded_text: str = SchemaField(
|
||||
description="The encoded text with special characters converted to escape sequences"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if encoding fails")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
|
||||
description="Encodes a string by converting special characters into escape sequences",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=TextEncoderBlock.Input,
|
||||
output_schema=TextEncoderBlock.Output,
|
||||
test_input={
|
||||
"text": """Hello
|
||||
World!
|
||||
This is a "quoted" string."""
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"encoded_text",
|
||||
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Encode the input text by converting special characters to escape sequences.
|
||||
|
||||
Args:
|
||||
input_data: The input containing the text to encode.
|
||||
**kwargs: Additional keyword arguments (unused).
|
||||
|
||||
Yields:
|
||||
The encoded text with escape sequences, or an error message if encoding fails.
|
||||
"""
|
||||
try:
|
||||
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
|
||||
"utf-8"
|
||||
)
|
||||
yield "encoded_text", encoded_text
|
||||
except Exception as e:
|
||||
yield "error", f"Encoding error: {str(e)}"
|
||||
@@ -162,16 +162,8 @@ class LinearClient:
|
||||
"searchTerm": team_name,
|
||||
}
|
||||
|
||||
result = await self.query(query, variables)
|
||||
nodes = result["teams"]["nodes"]
|
||||
|
||||
if not nodes:
|
||||
raise LinearAPIException(
|
||||
f"Team '{team_name}' not found. Check the team name or key and try again.",
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
return nodes[0]["id"]
|
||||
team_id = await self.query(query, variables)
|
||||
return team_id["teams"]["nodes"][0]["id"]
|
||||
except LinearAPIException as e:
|
||||
raise e
|
||||
|
||||
@@ -248,44 +240,17 @@ class LinearClient:
|
||||
except LinearAPIException as e:
|
||||
raise e
|
||||
|
||||
async def try_search_issues(
|
||||
self,
|
||||
term: str,
|
||||
max_results: int = 10,
|
||||
team_id: str | None = None,
|
||||
) -> list[Issue]:
|
||||
async def try_search_issues(self, term: str) -> list[Issue]:
|
||||
try:
|
||||
query = """
|
||||
query SearchIssues(
|
||||
$term: String!,
|
||||
$first: Int,
|
||||
$teamId: String
|
||||
) {
|
||||
searchIssues(
|
||||
term: $term,
|
||||
first: $first,
|
||||
teamId: $teamId
|
||||
) {
|
||||
query SearchIssues($term: String!, $includeComments: Boolean!) {
|
||||
searchIssues(term: $term, includeComments: $includeComments) {
|
||||
nodes {
|
||||
id
|
||||
identifier
|
||||
title
|
||||
description
|
||||
priority
|
||||
createdAt
|
||||
state {
|
||||
id
|
||||
name
|
||||
type
|
||||
}
|
||||
project {
|
||||
id
|
||||
name
|
||||
}
|
||||
assignee {
|
||||
id
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -293,8 +258,7 @@ class LinearClient:
|
||||
|
||||
variables: dict[str, Any] = {
|
||||
"term": term,
|
||||
"first": max_results,
|
||||
"teamId": team_id,
|
||||
"includeComments": True,
|
||||
}
|
||||
|
||||
issues = await self.query(query, variables)
|
||||
|
||||
@@ -17,7 +17,7 @@ from ._config import (
|
||||
LinearScope,
|
||||
linear,
|
||||
)
|
||||
from .models import CreateIssueResponse, Issue, State
|
||||
from .models import CreateIssueResponse, Issue
|
||||
|
||||
|
||||
class LinearCreateIssueBlock(Block):
|
||||
@@ -135,20 +135,9 @@ class LinearSearchIssuesBlock(Block):
|
||||
description="Linear credentials with read permissions",
|
||||
required_scopes={LinearScope.READ},
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="Maximum number of results to return",
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
team_name: str | None = SchemaField(
|
||||
description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
issues: list[Issue] = SchemaField(description="List of issues")
|
||||
error: str = SchemaField(description="Error message if the search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -156,11 +145,8 @@ class LinearSearchIssuesBlock(Block):
|
||||
description="Searches for issues on Linear",
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
||||
test_input={
|
||||
"term": "Test issue",
|
||||
"max_results": 10,
|
||||
"team_name": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||
@@ -170,14 +156,10 @@ class LinearSearchIssuesBlock(Block):
|
||||
[
|
||||
Issue(
|
||||
id="abc123",
|
||||
identifier="TST-123",
|
||||
identifier="abc123",
|
||||
title="Test issue",
|
||||
description="Test description",
|
||||
priority=1,
|
||||
state=State(
|
||||
id="state1", name="In Progress", type="started"
|
||||
),
|
||||
createdAt="2026-01-15T10:00:00.000Z",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -186,12 +168,10 @@ class LinearSearchIssuesBlock(Block):
|
||||
"search_issues": lambda *args, **kwargs: [
|
||||
Issue(
|
||||
id="abc123",
|
||||
identifier="TST-123",
|
||||
identifier="abc123",
|
||||
title="Test issue",
|
||||
description="Test description",
|
||||
priority=1,
|
||||
state=State(id="state1", name="In Progress", type="started"),
|
||||
createdAt="2026-01-15T10:00:00.000Z",
|
||||
)
|
||||
]
|
||||
},
|
||||
@@ -201,22 +181,10 @@ class LinearSearchIssuesBlock(Block):
|
||||
async def search_issues(
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
term: str,
|
||||
max_results: int = 10,
|
||||
team_name: str | None = None,
|
||||
) -> list[Issue]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
|
||||
# Resolve team name to ID if provided
|
||||
# Raises LinearAPIException with descriptive message if team not found
|
||||
team_id: str | None = None
|
||||
if team_name:
|
||||
team_id = await client.try_get_team_by_name(team_name=team_name)
|
||||
|
||||
return await client.try_search_issues(
|
||||
term=term,
|
||||
max_results=max_results,
|
||||
team_id=team_id,
|
||||
)
|
||||
response: list[Issue] = await client.try_search_issues(term=term)
|
||||
return response
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -228,10 +196,7 @@ class LinearSearchIssuesBlock(Block):
|
||||
"""Execute the issue search"""
|
||||
try:
|
||||
issues = await self.search_issues(
|
||||
credentials=credentials,
|
||||
term=input_data.term,
|
||||
max_results=input_data.max_results,
|
||||
team_name=input_data.team_name,
|
||||
credentials=credentials, term=input_data.term
|
||||
)
|
||||
yield "issues", issues
|
||||
except LinearAPIException as e:
|
||||
|
||||
@@ -36,21 +36,12 @@ class Project(BaseModel):
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class State(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str | None = (
|
||||
None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
|
||||
)
|
||||
|
||||
|
||||
class Issue(BaseModel):
|
||||
id: str
|
||||
identifier: str
|
||||
title: str
|
||||
description: str | None
|
||||
priority: int
|
||||
state: State | None = None
|
||||
project: Project | None = None
|
||||
createdAt: str | None = None
|
||||
comments: list[Comment] | None = None
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from backend.blocks.encoder_block import TextEncoderBlock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_basic():
|
||||
"""Test basic encoding of newlines and special characters."""
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "encoded_text"
|
||||
assert result[0][1] == "Hello\\nWorld"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_multiple_escapes():
|
||||
"""Test encoding of multiple escape sequences."""
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
async for output in block.run(
|
||||
TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
|
||||
):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "encoded_text"
|
||||
assert "\\n" in result[0][1]
|
||||
assert "\\t" in result[0][1]
|
||||
assert "\\r" in result[0][1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_unicode():
|
||||
"""Test that unicode characters are handled correctly."""
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "encoded_text"
|
||||
# Unicode characters should be escaped as \uXXXX sequences
|
||||
assert "\\n" in result[0][1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_empty_string():
|
||||
"""Test encoding of an empty string."""
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
async for output in block.run(TextEncoderBlock.Input(text="")):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "encoded_text"
|
||||
assert result[0][1] == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_encoder_error_handling():
|
||||
"""Test that encoding errors are handled gracefully."""
|
||||
from unittest.mock import patch
|
||||
|
||||
block = TextEncoderBlock()
|
||||
result = []
|
||||
|
||||
with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
|
||||
async for output in block.run(TextEncoderBlock.Input(text="test")):
|
||||
result.append(output)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "error"
|
||||
assert "Mocked encoding error" in result[0][1]
|
||||
@@ -165,13 +165,10 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
credentials: WebshareProxyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
video_id = self.extract_video_id(input_data.youtube_url)
|
||||
transcript = self.get_transcript(video_id, credentials)
|
||||
transcript_text = self.format_transcript(transcript=transcript)
|
||||
video_id = self.extract_video_id(input_data.youtube_url)
|
||||
yield "video_id", video_id
|
||||
|
||||
# Only yield after all operations succeed
|
||||
yield "video_id", video_id
|
||||
yield "transcript", transcript_text
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
transcript = self.get_transcript(video_id, credentials)
|
||||
transcript_text = self.format_transcript(transcript=transcript)
|
||||
|
||||
yield "transcript", transcript_text
|
||||
|
||||
@@ -246,9 +246,7 @@ class BlockSchema(BaseModel):
|
||||
f"is not of type {CredentialsMetaInput.__name__}"
|
||||
)
|
||||
|
||||
CredentialsMetaInput.validate_credentials_field_schema(
|
||||
cls.get_field_schema(field_name), field_name
|
||||
)
|
||||
credentials_fields[field_name].validate_credentials_field_schema(cls)
|
||||
|
||||
elif field_name in credentials_fields:
|
||||
raise KeyError(
|
||||
|
||||
@@ -134,16 +134,6 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
||||
user_credit.time_now = lambda: month1
|
||||
|
||||
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
|
||||
# in a different month than month1 (January). This fixes a timing bug
|
||||
# where if the test runs in early February, 35 days ago would be January,
|
||||
# matching the mocked month1 and preventing the refill from triggering.
|
||||
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
|
||||
await UserBalance.prisma().update(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={"updatedAt": dec_previous_year},
|
||||
)
|
||||
|
||||
# First call in month 1 should trigger refill
|
||||
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
assert balance == REFILL_VALUE # Should get 1000 credits
|
||||
|
||||
@@ -20,7 +20,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, BeforeValidator, Field
|
||||
from pydantic import BaseModel, BeforeValidator, Field, create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -30,6 +30,7 @@ from backend.data.db import prisma as db
|
||||
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
@@ -44,6 +45,7 @@ from .block import (
|
||||
AnyBlockSchema,
|
||||
Block,
|
||||
BlockInput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
EmptySchema,
|
||||
get_block,
|
||||
@@ -364,8 +366,39 @@ class Graph(BaseGraph):
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
schema = self._credentials_input_schema.jsonschema()
|
||||
|
||||
# Determine which credential fields are required based on credentials_optional metadata
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
required_fields = []
|
||||
|
||||
# Build a map of node_id -> node for quick lookup
|
||||
all_nodes = {node.id: node for node in self.nodes}
|
||||
for sub_graph in self.sub_graphs:
|
||||
for node in sub_graph.nodes:
|
||||
all_nodes[node.id] = node
|
||||
|
||||
for field_key, (
|
||||
_field_info,
|
||||
node_field_pairs,
|
||||
) in graph_credentials_inputs.items():
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
is_required = False
|
||||
for node_id, _field_name in node_field_pairs:
|
||||
node = all_nodes.get(node_id)
|
||||
if node and not node.credentials_optional:
|
||||
is_required = True
|
||||
break
|
||||
|
||||
if is_required:
|
||||
required_fields.append(field_key)
|
||||
|
||||
schema["required"] = required_fields
|
||||
return schema
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
logger.debug(
|
||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||
f"{graph_credentials_inputs}"
|
||||
@@ -373,8 +406,8 @@ class Graph(BaseGraph):
|
||||
|
||||
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
||||
graph_cred_fields = list(graph_credentials_inputs.values())
|
||||
for i, (field, keys, _) in enumerate(graph_cred_fields):
|
||||
for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
|
||||
for i, (field, keys) in enumerate(graph_cred_fields):
|
||||
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||
if field.provider != other_field.provider:
|
||||
continue
|
||||
if ProviderName.HTTP in field.provider:
|
||||
@@ -390,78 +423,31 @@ class Graph(BaseGraph):
|
||||
f"keys: {keys} <> {other_keys}."
|
||||
)
|
||||
|
||||
# Build JSON schema directly to avoid expensive create_model + validation overhead
|
||||
properties = {}
|
||||
required_fields = []
|
||||
|
||||
for agg_field_key, (
|
||||
field_info,
|
||||
_,
|
||||
is_required,
|
||||
) in graph_credentials_inputs.items():
|
||||
providers = list(field_info.provider)
|
||||
cred_types = list(field_info.supported_types)
|
||||
|
||||
field_schema: dict[str, Any] = {
|
||||
"credentials_provider": providers,
|
||||
"credentials_types": cred_types,
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"title": "Id", "type": "string"},
|
||||
"title": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"default": None,
|
||||
"title": "Title",
|
||||
},
|
||||
"provider": {
|
||||
"title": "Provider",
|
||||
"type": "string",
|
||||
**(
|
||||
{"enum": providers}
|
||||
if len(providers) > 1
|
||||
else {"const": providers[0]}
|
||||
),
|
||||
},
|
||||
"type": {
|
||||
"title": "Type",
|
||||
"type": "string",
|
||||
**(
|
||||
{"enum": cred_types}
|
||||
if len(cred_types) > 1
|
||||
else {"const": cred_types[0]}
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["id", "provider", "type"],
|
||||
}
|
||||
|
||||
# Add other (optional) field info items
|
||||
field_schema.update(
|
||||
field_info.model_dump(
|
||||
by_alias=True,
|
||||
exclude_defaults=True,
|
||||
exclude={"provider", "supported_types"}, # already included above
|
||||
)
|
||||
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
|
||||
agg_field_key: (
|
||||
CredentialsMetaInput[
|
||||
Literal[tuple(field_info.provider)], # type: ignore
|
||||
Literal[tuple(field_info.supported_types)], # type: ignore
|
||||
],
|
||||
CredentialsField(
|
||||
required_scopes=set(field_info.required_scopes or []),
|
||||
discriminator=field_info.discriminator,
|
||||
discriminator_mapping=field_info.discriminator_mapping,
|
||||
discriminator_values=field_info.discriminator_values,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure field schema is well-formed
|
||||
CredentialsMetaInput.validate_credentials_field_schema(
|
||||
field_schema, agg_field_key
|
||||
)
|
||||
|
||||
properties[agg_field_key] = field_schema
|
||||
if is_required:
|
||||
required_fields.append(agg_field_key)
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required_fields,
|
||||
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||
}
|
||||
|
||||
return create_model(
|
||||
self.name.replace(" ", "") + "CredentialsInputSchema",
|
||||
__base__=BlockSchema,
|
||||
**fields, # type: ignore
|
||||
)
|
||||
|
||||
def aggregate_credentials_inputs(
|
||||
self,
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
|
||||
"""
|
||||
Returns:
|
||||
dict[aggregated_field_key, tuple(
|
||||
@@ -469,19 +455,13 @@ class Graph(BaseGraph):
|
||||
(now includes discriminator_values from matching nodes)
|
||||
set[(node_id, field_name)]: Node credentials fields that are
|
||||
compatible with this aggregated field spec
|
||||
bool: True if the field is required (any node has credentials_optional=False)
|
||||
)]
|
||||
"""
|
||||
# First collect all credential field data with input defaults
|
||||
# Track (field_info, (node_id, field_name), is_required) for each credential field
|
||||
node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
|
||||
node_required_map: dict[str, bool] = {} # node_id -> is_required
|
||||
node_credential_data = []
|
||||
|
||||
for graph in [self] + self.sub_graphs:
|
||||
for node in graph.nodes:
|
||||
# Track if this node requires credentials (credentials_optional=False means required)
|
||||
node_required_map[node.id] = not node.credentials_optional
|
||||
|
||||
for (
|
||||
field_name,
|
||||
field_info,
|
||||
@@ -505,21 +485,7 @@ class Graph(BaseGraph):
|
||||
)
|
||||
|
||||
# Combine credential field info (this will merge discriminator_values automatically)
|
||||
combined = CredentialsFieldInfo.combine(*node_credential_data)
|
||||
|
||||
# Add is_required flag to each aggregated field
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
return {
|
||||
key: (
|
||||
field_info,
|
||||
node_field_pairs,
|
||||
any(
|
||||
node_required_map.get(node_id, True)
|
||||
for node_id, _ in node_field_pairs
|
||||
),
|
||||
)
|
||||
for key, (field_info, node_field_pairs) in combined.items()
|
||||
}
|
||||
return CredentialsFieldInfo.combine(*node_credential_data)
|
||||
|
||||
|
||||
class GraphModel(Graph):
|
||||
@@ -866,55 +832,16 @@ class GraphModel(Graph):
|
||||
)
|
||||
|
||||
|
||||
class GraphMeta(BaseModel):
|
||||
"""
|
||||
Graph metadata without nodes/links, used for list endpoints.
|
||||
|
||||
This is a flat, lightweight model (not inheriting from Graph) to avoid recomputing
|
||||
expensive computed fields. Values are copied from GraphModel.
|
||||
"""
|
||||
|
||||
id: str
|
||||
version: int = 1
|
||||
is_active: bool = True
|
||||
name: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
forked_from_id: str | None = None
|
||||
forked_from_version: int | None = None
|
||||
class GraphMeta(Graph):
|
||||
user_id: str
|
||||
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
credentials_input_schema: dict[str, Any]
|
||||
has_external_trigger: bool
|
||||
has_human_in_the_loop: bool
|
||||
has_sensitive_action: bool
|
||||
trigger_setup_info: Optional["GraphTriggerInfo"]
|
||||
# Easy work-around to prevent exposing nodes and links in the API response
|
||||
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
|
||||
links: list[Link] = Field(default=[], exclude=True)
|
||||
|
||||
@staticmethod
|
||||
def from_graph(graph: "GraphModel") -> "GraphMeta":
|
||||
return GraphMeta(
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.is_active,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
instructions=graph.instructions,
|
||||
recommended_schedule_cron=graph.recommended_schedule_cron,
|
||||
forked_from_id=graph.forked_from_id,
|
||||
forked_from_version=graph.forked_from_version,
|
||||
user_id=graph.user_id,
|
||||
# Pre-computed values (were @computed_field on Graph)
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
has_external_trigger=graph.has_external_trigger,
|
||||
has_human_in_the_loop=graph.has_human_in_the_loop,
|
||||
has_sensitive_action=graph.has_sensitive_action,
|
||||
trigger_setup_info=graph.trigger_setup_info,
|
||||
credentials_input_schema=graph.credentials_input_schema,
|
||||
)
|
||||
def from_graph(graph: GraphModel) -> "GraphMeta":
|
||||
return GraphMeta(**graph.model_dump())
|
||||
|
||||
|
||||
class GraphsPaginated(BaseModel):
|
||||
@@ -993,9 +920,9 @@ async def list_graphs_paginated(
|
||||
graph_models: list[GraphMeta] = []
|
||||
for graph in graphs:
|
||||
try:
|
||||
# GraphMeta.from_graph() accesses all computed fields on the GraphModel,
|
||||
# which validates that the graph is well formed (e.g. no unknown block_ids).
|
||||
graph_meta = GraphModel.from_db(graph).meta()
|
||||
# Trigger serialization to validate that the graph is well formed
|
||||
graph_meta.model_dump()
|
||||
graph_models.append(graph_meta)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing graph {graph.id}: {e}")
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import (
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from prisma.enums import CreditTransactionType, OnboardingStep
|
||||
@@ -41,7 +42,6 @@ from typing_extensions import TypedDict
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.request import parse_url
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
# Type alias for any provider name (including custom ones)
|
||||
@@ -163,6 +163,7 @@ class User(BaseModel):
|
||||
if TYPE_CHECKING:
|
||||
from prisma.models import User as PrismaUser
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
T = TypeVar("T")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -396,25 +397,19 @@ class HostScopedCredentials(_BaseCredentials):
|
||||
def matches_url(self, url: str) -> bool:
|
||||
"""Check if this credential should be applied to the given URL."""
|
||||
|
||||
request_host, request_port = _extract_host_from_url(url)
|
||||
cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
|
||||
parsed_url = urlparse(url)
|
||||
# Extract hostname without port
|
||||
request_host = parsed_url.hostname
|
||||
if not request_host:
|
||||
return False
|
||||
|
||||
# If a port is specified in credential host, the request host port must match
|
||||
if cred_scope_port is not None and request_port != cred_scope_port:
|
||||
return False
|
||||
# Non-standard ports are only allowed if explicitly specified in credential host
|
||||
elif cred_scope_port is None and request_port not in (80, 443, None):
|
||||
return False
|
||||
|
||||
# Simple host matching
|
||||
if cred_scope_host == request_host:
|
||||
# Simple host matching - exact match or wildcard subdomain match
|
||||
if self.host == request_host:
|
||||
return True
|
||||
|
||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||
if cred_scope_host.startswith("*."):
|
||||
domain = cred_scope_host[2:] # Remove "*."
|
||||
if self.host.startswith("*."):
|
||||
domain = self.host[2:] # Remove "*."
|
||||
return request_host.endswith(f".{domain}") or request_host == domain
|
||||
|
||||
return False
|
||||
@@ -507,13 +502,15 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
||||
return get_args(cls.model_fields["type"].annotation)
|
||||
|
||||
@staticmethod
|
||||
def validate_credentials_field_schema(
|
||||
field_schema: dict[str, Any], field_name: str
|
||||
):
|
||||
@classmethod
|
||||
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
|
||||
"""Validates the schema of a credentials input field"""
|
||||
field_name = next(
|
||||
name for name, type in model.get_credentials_fields().items() if type is cls
|
||||
)
|
||||
field_schema = model.jsonschema()["properties"][field_name]
|
||||
try:
|
||||
field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||
except ValidationError as e:
|
||||
if "Field required [type=missing" not in str(e):
|
||||
raise
|
||||
@@ -523,11 +520,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
f"{field_schema}"
|
||||
) from e
|
||||
|
||||
providers = field_info.provider
|
||||
providers = cls.allowed_providers()
|
||||
if (
|
||||
providers is not None
|
||||
and len(providers) > 1
|
||||
and not field_info.discriminator
|
||||
and not schema_extra.discriminator
|
||||
):
|
||||
raise TypeError(
|
||||
f"Multi-provider CredentialsField '{field_name}' "
|
||||
@@ -554,13 +551,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
|
||||
|
||||
def _extract_host_from_url(url: str) -> tuple[str, int | None]:
|
||||
"""Extract host and port from URL for grouping host-scoped credentials."""
|
||||
def _extract_host_from_url(url: str) -> str:
|
||||
"""Extract host from URL for grouping host-scoped credentials."""
|
||||
try:
|
||||
parsed = parse_url(url)
|
||||
return parsed.hostname or url, parsed.port
|
||||
parsed = urlparse(url)
|
||||
return parsed.hostname or url
|
||||
except Exception:
|
||||
return "", None
|
||||
return ""
|
||||
|
||||
|
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
@@ -609,7 +606,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
providers = frozenset(
|
||||
[cast(CP, "http")]
|
||||
+ [
|
||||
cast(CP, parse_url(str(value)).netloc)
|
||||
cast(CP, _extract_host_from_url(str(value)))
|
||||
for value in field.discriminator_values
|
||||
]
|
||||
)
|
||||
|
||||
@@ -79,23 +79,10 @@ class TestHostScopedCredentials:
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
# Non-standard ports require explicit port in credential host
|
||||
assert not creds.matches_url("http://localhost:8080/api/v1")
|
||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
||||
assert creds.matches_url("http://localhost/simple")
|
||||
|
||||
def test_matches_url_with_explicit_port(self):
|
||||
"""Test URL matching with explicit port in credential host."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="localhost:8080",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||
assert not creds.matches_url("http://localhost:3000/api/v1")
|
||||
assert not creds.matches_url("http://localhost/simple")
|
||||
|
||||
def test_empty_headers_dict(self):
|
||||
"""Test HostScopedCredentials with empty headers."""
|
||||
creds = HostScopedCredentials(
|
||||
@@ -141,20 +128,8 @@ class TestHostScopedCredentials:
|
||||
("*.example.com", "https://sub.api.example.com/test", True),
|
||||
("*.example.com", "https://example.com/test", True),
|
||||
("*.example.com", "https://example.org/test", False),
|
||||
# Non-standard ports require explicit port in credential host
|
||||
("localhost", "http://localhost:3000/test", False),
|
||||
("localhost:3000", "http://localhost:3000/test", True),
|
||||
("localhost", "http://localhost:3000/test", True),
|
||||
("localhost", "http://127.0.0.1:3000/test", False),
|
||||
# IPv6 addresses (frontend stores with brackets via URL.hostname)
|
||||
("[::1]", "http://[::1]/test", True),
|
||||
("[::1]", "http://[::1]:80/test", True),
|
||||
("[::1]", "https://[::1]:443/test", True),
|
||||
("[::1]", "http://[::1]:8080/test", False), # Non-standard port
|
||||
("[::1]:8080", "http://[::1]:8080/test", True),
|
||||
("[::1]:8080", "http://[::1]:9090/test", False),
|
||||
("[2001:db8::1]", "http://[2001:db8::1]/path", True),
|
||||
("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
|
||||
("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
|
||||
],
|
||||
)
|
||||
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
||||
|
||||
@@ -373,7 +373,7 @@ def make_node_credentials_input_map(
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
|
||||
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
||||
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
|
||||
# Best-effort map: skip missing items
|
||||
if graph_input_name not in graph_credentials_input:
|
||||
continue
|
||||
|
||||
@@ -157,7 +157,12 @@ async def validate_url(
|
||||
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
||||
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
||||
"""
|
||||
parsed = parse_url(url)
|
||||
# Canonicalize URL
|
||||
url = url.strip("/ ").replace("\\", "/")
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme:
|
||||
url = f"http://{url}"
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Check scheme
|
||||
if parsed.scheme not in ALLOWED_SCHEMES:
|
||||
@@ -215,17 +220,6 @@ async def validate_url(
|
||||
)
|
||||
|
||||
|
||||
def parse_url(url: str) -> URL:
|
||||
"""Canonicalizes and parses a URL string."""
|
||||
url = url.strip("/ ").replace("\\", "/")
|
||||
|
||||
# Ensure scheme is present for proper parsing
|
||||
if not re.match(r"[a-z0-9+.\-]+://", url):
|
||||
url = f"http://{url}"
|
||||
|
||||
return urlparse(url)
|
||||
|
||||
|
||||
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
|
||||
"""
|
||||
Pins a URL to a specific IP address to prevent DNS rebinding attacks.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
"type": "object"
|
||||
},
|
||||
"description": "A test graph",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
"type": "object"
|
||||
},
|
||||
"description": "A test graph",
|
||||
@@ -26,6 +27,7 @@
|
||||
"type": "object"
|
||||
},
|
||||
"recommended_schedule_cron": null,
|
||||
"sub_graphs": [],
|
||||
"trigger_setup_info": null,
|
||||
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"version": 1
|
||||
|
||||
@@ -1,17 +1,6 @@
|
||||
import { OAuthPopupResultMessage } from "./types";
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
/**
|
||||
* Safely encode a value as JSON for embedding in a script tag.
|
||||
* Escapes characters that could break out of the script context to prevent XSS.
|
||||
*/
|
||||
function safeJsonStringify(value: unknown): string {
|
||||
return JSON.stringify(value)
|
||||
.replace(/</g, "\\u003c")
|
||||
.replace(/>/g, "\\u003e")
|
||||
.replace(/&/g, "\\u0026");
|
||||
}
|
||||
|
||||
// This route is intended to be used as the callback for integration OAuth flows,
|
||||
// controlled by the CredentialsInput component. The CredentialsInput opens the login
|
||||
// page in a pop-up window, which then redirects to this route to close the loop.
|
||||
@@ -34,13 +23,12 @@ export async function GET(request: Request) {
|
||||
console.debug("Sending message to opener:", message);
|
||||
|
||||
// Return a response with the message as JSON and a script to close the window
|
||||
// Use safeJsonStringify to prevent XSS by escaping <, >, and & characters
|
||||
return new NextResponse(
|
||||
`
|
||||
<html>
|
||||
<body>
|
||||
<script>
|
||||
window.opener.postMessage(${safeJsonStringify(message)});
|
||||
window.opener.postMessage(${JSON.stringify(message)});
|
||||
window.close();
|
||||
</script>
|
||||
</body>
|
||||
|
||||
@@ -26,20 +26,8 @@ export function buildCopilotChatUrl(prompt: string): string {
|
||||
|
||||
export function getQuickActions(): string[] {
|
||||
return [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
"Show me what I can automate",
|
||||
"Design a custom workflow",
|
||||
"Help me with content creation",
|
||||
];
|
||||
}
|
||||
|
||||
export function getInputPlaceholder(width?: number) {
|
||||
if (!width) return "What's your role and what eats up most of your day?";
|
||||
|
||||
if (width < 500) {
|
||||
return "I'm a chef and I hate...";
|
||||
}
|
||||
if (width <= 1080) {
|
||||
return "What's your role and what eats up most of your day?";
|
||||
}
|
||||
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useCopilotStore } from "./copilot-page-store";
|
||||
import { getInputPlaceholder } from "./helpers";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
export default function CopilotPage() {
|
||||
@@ -16,25 +14,8 @@ export default function CopilotPage() {
|
||||
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||
|
||||
const [inputPlaceholder, setInputPlaceholder] = useState(
|
||||
getInputPlaceholder(),
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
||||
};
|
||||
|
||||
handleResize();
|
||||
|
||||
window.addEventListener("resize", handleResize);
|
||||
return () => window.removeEventListener("resize", handleResize);
|
||||
}, []);
|
||||
|
||||
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
|
||||
state;
|
||||
|
||||
const {
|
||||
handleQuickAction,
|
||||
startChatWithPrompt,
|
||||
@@ -92,7 +73,7 @@ export default function CopilotPage() {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
||||
<div className="w-full text-center">
|
||||
{isLoading ? (
|
||||
<div className="mx-auto max-w-2xl">
|
||||
@@ -109,25 +90,25 @@ export default function CopilotPage() {
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="mx-auto max-w-3xl">
|
||||
<div className="mx-auto max-w-2xl">
|
||||
<Text
|
||||
variant="h3"
|
||||
className="mb-1 !text-[1.375rem] text-zinc-700"
|
||||
className="mb-3 !text-[1.375rem] text-zinc-700"
|
||||
>
|
||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||
</Text>
|
||||
<Text variant="h3" className="mb-8 !font-normal">
|
||||
Tell me about your work — I'll find what to automate.
|
||||
What do you want to automate?
|
||||
</Text>
|
||||
|
||||
<div className="mb-6">
|
||||
<ChatInput
|
||||
onSend={startChatWithPrompt}
|
||||
placeholder={inputPlaceholder}
|
||||
placeholder='You can search or just ask - e.g. "create a blog post outline"'
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{quickActions.map((action) => (
|
||||
<Button
|
||||
key={action}
|
||||
@@ -135,7 +116,7 @@ export default function CopilotPage() {
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => handleQuickAction(action)}
|
||||
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
||||
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
|
||||
>
|
||||
{action}
|
||||
</Button>
|
||||
|
||||
@@ -7804,57 +7804,68 @@
|
||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||
"title": "Forked From Version"
|
||||
},
|
||||
"sub_graphs": {
|
||||
"items": { "$ref": "#/components/schemas/BaseGraph-Output" },
|
||||
"type": "array",
|
||||
"title": "Sub Graphs",
|
||||
"default": []
|
||||
},
|
||||
"user_id": { "type": "string", "title": "User Id" },
|
||||
"input_schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Input Schema"
|
||||
"title": "Input Schema",
|
||||
"readOnly": true
|
||||
},
|
||||
"output_schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Output Schema"
|
||||
},
|
||||
"credentials_input_schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Credentials Input Schema"
|
||||
"title": "Output Schema",
|
||||
"readOnly": true
|
||||
},
|
||||
"has_external_trigger": {
|
||||
"type": "boolean",
|
||||
"title": "Has External Trigger"
|
||||
"title": "Has External Trigger",
|
||||
"readOnly": true
|
||||
},
|
||||
"has_human_in_the_loop": {
|
||||
"type": "boolean",
|
||||
"title": "Has Human In The Loop"
|
||||
"title": "Has Human In The Loop",
|
||||
"readOnly": true
|
||||
},
|
||||
"has_sensitive_action": {
|
||||
"type": "boolean",
|
||||
"title": "Has Sensitive Action"
|
||||
"title": "Has Sensitive Action",
|
||||
"readOnly": true
|
||||
},
|
||||
"trigger_setup_info": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/GraphTriggerInfo" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
],
|
||||
"readOnly": true
|
||||
},
|
||||
"credentials_input_schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Credentials Input Schema",
|
||||
"readOnly": true
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"name",
|
||||
"description",
|
||||
"user_id",
|
||||
"input_schema",
|
||||
"output_schema",
|
||||
"credentials_input_schema",
|
||||
"has_external_trigger",
|
||||
"has_human_in_the_loop",
|
||||
"has_sensitive_action",
|
||||
"trigger_setup_info"
|
||||
"trigger_setup_info",
|
||||
"credentials_input_schema"
|
||||
],
|
||||
"title": "GraphMeta",
|
||||
"description": "Graph metadata without nodes/links, used for list endpoints.\n\nThis is a flat, lightweight model (not inheriting from Graph) to avoid recomputing\nexpensive computed fields. Values are copied from GraphModel."
|
||||
"title": "GraphMeta"
|
||||
},
|
||||
"GraphModel": {
|
||||
"properties": {
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { GlobeHemisphereEastIcon } from "@phosphor-icons/react";
|
||||
import { useEffect } from "react";
|
||||
@@ -55,6 +56,10 @@ export function ChatContainer({
|
||||
onStreamingChange?.(isStreaming);
|
||||
}, [isStreaming, onStreamingChange]);
|
||||
|
||||
const breakpoint = useBreakpoint();
|
||||
const isMobile =
|
||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -122,7 +127,11 @@ export function ChatContainer({
|
||||
disabled={isStreaming || !sessionId}
|
||||
isStreaming={isStreaming}
|
||||
onStop={stopStreaming}
|
||||
placeholder="What else can I help with?"
|
||||
placeholder={
|
||||
isMobile
|
||||
? "You can search or just ask"
|
||||
: 'You can search or just ask — e.g. "create a blog post outline"'
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -74,20 +74,19 @@ export function ChatInput({
|
||||
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
|
||||
)}
|
||||
>
|
||||
{!value && !isRecording && (
|
||||
<div
|
||||
className="pointer-events-none absolute inset-0 top-0.5 flex items-center justify-start pl-14 text-[1rem] text-zinc-400"
|
||||
aria-hidden="true"
|
||||
>
|
||||
{isTranscribing ? "Transcribing..." : placeholder}
|
||||
</div>
|
||||
)}
|
||||
<textarea
|
||||
id={inputId}
|
||||
aria-label="Chat message input"
|
||||
value={value}
|
||||
onChange={handleChange}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={
|
||||
isTranscribing
|
||||
? "Transcribing..."
|
||||
: isRecording
|
||||
? ""
|
||||
: placeholder
|
||||
}
|
||||
disabled={isInputDisabled}
|
||||
rows={1}
|
||||
className={cn(
|
||||
@@ -123,14 +122,13 @@ export function ChatInput({
|
||||
size="icon"
|
||||
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
||||
onClick={toggleRecording}
|
||||
disabled={disabled || isTranscribing || isStreaming}
|
||||
disabled={disabled || isTranscribing}
|
||||
className={cn(
|
||||
isRecording
|
||||
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
|
||||
: isTranscribing
|
||||
? "border-zinc-300 bg-zinc-100 text-zinc-400"
|
||||
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
||||
isStreaming && "opacity-40",
|
||||
)}
|
||||
>
|
||||
{isTranscribing ? (
|
||||
|
||||
@@ -38,8 +38,8 @@ export function AudioWaveform({
|
||||
// Create audio context and analyser
|
||||
const audioContext = new AudioContext();
|
||||
const analyser = audioContext.createAnalyser();
|
||||
analyser.fftSize = 256;
|
||||
analyser.smoothingTimeConstant = 0.3;
|
||||
analyser.fftSize = 512;
|
||||
analyser.smoothingTimeConstant = 0.8;
|
||||
|
||||
// Connect the stream to the analyser
|
||||
const source = audioContext.createMediaStreamSource(stream);
|
||||
@@ -73,11 +73,10 @@ export function AudioWaveform({
|
||||
maxAmplitude = Math.max(maxAmplitude, amplitude);
|
||||
}
|
||||
|
||||
// Normalize amplitude (0-128 range) to 0-1
|
||||
const normalized = maxAmplitude / 128;
|
||||
// Apply sensitivity boost (multiply by 4) and use sqrt curve to amplify quiet sounds
|
||||
const boosted = Math.min(1, Math.sqrt(normalized) * 4);
|
||||
const height = minBarHeight + boosted * (maxBarHeight - minBarHeight);
|
||||
// Map amplitude (0-128) to bar height
|
||||
const normalized = (maxAmplitude / 128) * 255;
|
||||
const height =
|
||||
minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
|
||||
newBars.push(height);
|
||||
}
|
||||
|
||||
|
||||
@@ -224,7 +224,7 @@ export function useVoiceRecording({
|
||||
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
|
||||
);
|
||||
|
||||
const showMicButton = isSupported;
|
||||
const showMicButton = isSupported && !isStreaming;
|
||||
const isInputDisabled = disabled || isStreaming || isTranscribing;
|
||||
|
||||
// Cleanup on unmount
|
||||
|
||||
@@ -346,7 +346,6 @@ export function ChatMessage({
|
||||
toolId={message.toolId}
|
||||
toolName={message.toolName}
|
||||
result={message.result}
|
||||
onSendMessage={onSendMessage}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -73,7 +73,6 @@ export function MessageList({
|
||||
key={index}
|
||||
message={message}
|
||||
prevMessage={messages[index - 1]}
|
||||
onSendMessage={onSendMessage}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,13 +5,11 @@ import { shouldSkipAgentOutput } from "../../helpers";
|
||||
export interface LastToolResponseProps {
|
||||
message: ChatMessageData;
|
||||
prevMessage: ChatMessageData | undefined;
|
||||
onSendMessage?: (content: string) => void;
|
||||
}
|
||||
|
||||
export function LastToolResponse({
|
||||
message,
|
||||
prevMessage,
|
||||
onSendMessage,
|
||||
}: LastToolResponseProps) {
|
||||
if (message.type !== "tool_response") return null;
|
||||
|
||||
@@ -23,7 +21,6 @@ export function LastToolResponse({
|
||||
toolId={message.toolId}
|
||||
toolName={message.toolName}
|
||||
result={message.result}
|
||||
onSendMessage={onSendMessage}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { Progress } from "@/components/atoms/Progress/Progress";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||
import { useAsymptoticProgress } from "../ToolCallMessage/useAsymptoticProgress";
|
||||
|
||||
export interface ThinkingMessageProps {
|
||||
className?: string;
|
||||
@@ -13,19 +11,18 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
||||
const [showCoffeeMessage, setShowCoffeeMessage] = useState(false);
|
||||
const timerRef = useRef<NodeJS.Timeout | null>(null);
|
||||
const coffeeTimerRef = useRef<NodeJS.Timeout | null>(null);
|
||||
const progress = useAsymptoticProgress(showCoffeeMessage);
|
||||
|
||||
useEffect(() => {
|
||||
if (timerRef.current === null) {
|
||||
timerRef.current = setTimeout(() => {
|
||||
setShowSlowLoader(true);
|
||||
}, 3000);
|
||||
}, 8000);
|
||||
}
|
||||
|
||||
if (coffeeTimerRef.current === null) {
|
||||
coffeeTimerRef.current = setTimeout(() => {
|
||||
setShowCoffeeMessage(true);
|
||||
}, 8000);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
return () => {
|
||||
@@ -52,18 +49,9 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
||||
<AIChatBubble>
|
||||
<div className="transition-all duration-500 ease-in-out">
|
||||
{showCoffeeMessage ? (
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<div className="flex w-full max-w-[280px] flex-col gap-1.5">
|
||||
<div className="flex items-center justify-between text-xs text-neutral-500">
|
||||
<span>Working on it...</span>
|
||||
<span>{Math.round(progress)}%</span>
|
||||
</div>
|
||||
<Progress value={progress} className="h-2 w-full" />
|
||||
</div>
|
||||
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
||||
This could take a few minutes, grab a coffee ☕️
|
||||
</span>
|
||||
</div>
|
||||
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
||||
This could take a few minutes, grab a coffee ☕️
|
||||
</span>
|
||||
) : showSlowLoader ? (
|
||||
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
||||
Taking a bit more time...
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
/**
|
||||
* Hook that returns a progress value that starts fast and slows down,
|
||||
* asymptotically approaching but never reaching the max value.
|
||||
*
|
||||
* Uses a half-life formula: progress = max * (1 - 0.5^(time/halfLife))
|
||||
* This creates the "game loading bar" effect where:
|
||||
* - 50% is reached at halfLifeSeconds
|
||||
* - 75% is reached at 2 * halfLifeSeconds
|
||||
* - 87.5% is reached at 3 * halfLifeSeconds
|
||||
* - and so on...
|
||||
*
|
||||
* @param isActive - Whether the progress should be animating
|
||||
* @param halfLifeSeconds - Time in seconds to reach 50% progress (default: 30)
|
||||
* @param maxProgress - Maximum progress value to approach (default: 100)
|
||||
* @param intervalMs - Update interval in milliseconds (default: 100)
|
||||
* @returns Current progress value (0-maxProgress)
|
||||
*/
|
||||
export function useAsymptoticProgress(
|
||||
isActive: boolean,
|
||||
halfLifeSeconds = 30,
|
||||
maxProgress = 100,
|
||||
intervalMs = 100,
|
||||
) {
|
||||
const [progress, setProgress] = useState(0);
|
||||
const elapsedTimeRef = useRef(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isActive) {
|
||||
setProgress(0);
|
||||
elapsedTimeRef.current = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const interval = setInterval(() => {
|
||||
elapsedTimeRef.current += intervalMs / 1000;
|
||||
// Half-life approach: progress = max * (1 - 0.5^(time/halfLife))
|
||||
// At t=halfLife: 50%, at t=2*halfLife: 75%, at t=3*halfLife: 87.5%, etc.
|
||||
const newProgress =
|
||||
maxProgress *
|
||||
(1 - Math.pow(0.5, elapsedTimeRef.current / halfLifeSeconds));
|
||||
setProgress(newProgress);
|
||||
}, intervalMs);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [isActive, halfLifeSeconds, maxProgress, intervalMs]);
|
||||
|
||||
return progress;
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV2GetLibraryAgent } from "@/app/api/__generated__/endpoints/library/library";
|
||||
import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
|
||||
import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
||||
import { RunAgentModal } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
CheckCircleIcon,
|
||||
PencilLineIcon,
|
||||
PlayIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||
|
||||
interface Props {
|
||||
agentName: string;
|
||||
libraryAgentId: string;
|
||||
onSendMessage?: (content: string) => void;
|
||||
}
|
||||
|
||||
export function AgentCreatedPrompt({
|
||||
agentName,
|
||||
libraryAgentId,
|
||||
onSendMessage,
|
||||
}: Props) {
|
||||
// Fetch library agent eagerly so modal is ready when user clicks
|
||||
const { data: libraryAgentResponse, isLoading } = useGetV2GetLibraryAgent(
|
||||
libraryAgentId,
|
||||
{
|
||||
query: {
|
||||
enabled: !!libraryAgentId,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const libraryAgent =
|
||||
libraryAgentResponse?.status === 200 ? libraryAgentResponse.data : null;
|
||||
|
||||
function handleRunWithPlaceholders() {
|
||||
onSendMessage?.(
|
||||
`Run the agent "${agentName}" with placeholder/example values so I can test it.`,
|
||||
);
|
||||
}
|
||||
|
||||
function handleRunCreated(execution: GraphExecutionMeta) {
|
||||
onSendMessage?.(
|
||||
`I've started the agent "${agentName}". The execution ID is ${execution.id}. Please monitor its progress and let me know when it completes.`,
|
||||
);
|
||||
}
|
||||
|
||||
function handleScheduleCreated(schedule: GraphExecutionJobInfo) {
|
||||
const scheduleInfo = schedule.cron
|
||||
? `with cron schedule "${schedule.cron}"`
|
||||
: "to run on the specified schedule";
|
||||
onSendMessage?.(
|
||||
`I've scheduled the agent "${agentName}" ${scheduleInfo}. The schedule ID is ${schedule.id}.`,
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<AIChatBubble>
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex h-8 w-8 items-center justify-center rounded-full bg-green-100">
|
||||
<CheckCircleIcon
|
||||
size={18}
|
||||
weight="fill"
|
||||
className="text-green-600"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<Text variant="body-medium" className="text-neutral-900">
|
||||
Agent Created Successfully
|
||||
</Text>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
"{agentName}" is ready to test
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-2">
|
||||
<Text variant="small-medium" className="text-neutral-700">
|
||||
Ready to test?
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={handleRunWithPlaceholders}
|
||||
className="gap-2"
|
||||
>
|
||||
<PlayIcon size={16} />
|
||||
Run with example values
|
||||
</Button>
|
||||
{libraryAgent ? (
|
||||
<RunAgentModal
|
||||
triggerSlot={
|
||||
<Button variant="outline" size="small" className="gap-2">
|
||||
<PencilLineIcon size={16} />
|
||||
Run with my inputs
|
||||
</Button>
|
||||
}
|
||||
agent={libraryAgent}
|
||||
onRunCreated={handleRunCreated}
|
||||
onScheduleCreated={handleScheduleCreated}
|
||||
/>
|
||||
) : (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
loading={isLoading}
|
||||
disabled
|
||||
className="gap-2"
|
||||
>
|
||||
<PencilLineIcon size={16} />
|
||||
Run with my inputs
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
or just ask me
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</AIChatBubble>
|
||||
);
|
||||
}
|
||||
@@ -2,13 +2,11 @@ import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { ToolResult } from "@/types/chat";
|
||||
import { WarningCircleIcon } from "@phosphor-icons/react";
|
||||
import { AgentCreatedPrompt } from "./AgentCreatedPrompt";
|
||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||
import {
|
||||
formatToolResponse,
|
||||
getErrorMessage,
|
||||
isAgentSavedResponse,
|
||||
isErrorResponse,
|
||||
} from "./helpers";
|
||||
|
||||
@@ -18,7 +16,6 @@ export interface ToolResponseMessageProps {
|
||||
result?: ToolResult;
|
||||
success?: boolean;
|
||||
className?: string;
|
||||
onSendMessage?: (content: string) => void;
|
||||
}
|
||||
|
||||
export function ToolResponseMessage({
|
||||
@@ -27,7 +24,6 @@ export function ToolResponseMessage({
|
||||
result,
|
||||
success: _success,
|
||||
className,
|
||||
onSendMessage,
|
||||
}: ToolResponseMessageProps) {
|
||||
if (isErrorResponse(result)) {
|
||||
const errorMessage = getErrorMessage(result);
|
||||
@@ -47,18 +43,6 @@ export function ToolResponseMessage({
|
||||
);
|
||||
}
|
||||
|
||||
// Check for agent_saved response - show special prompt
|
||||
const agentSavedData = isAgentSavedResponse(result);
|
||||
if (agentSavedData.isSaved) {
|
||||
return (
|
||||
<AgentCreatedPrompt
|
||||
agentName={agentSavedData.agentName}
|
||||
libraryAgentId={agentSavedData.libraryAgentId}
|
||||
onSendMessage={onSendMessage}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const formattedText = formatToolResponse(result, toolName);
|
||||
|
||||
return (
|
||||
|
||||
@@ -6,43 +6,6 @@ function stripInternalReasoning(content: string): string {
|
||||
.trim();
|
||||
}
|
||||
|
||||
export interface AgentSavedData {
|
||||
isSaved: boolean;
|
||||
agentName: string;
|
||||
agentId: string;
|
||||
libraryAgentId: string;
|
||||
libraryAgentLink: string;
|
||||
}
|
||||
|
||||
export function isAgentSavedResponse(result: unknown): AgentSavedData {
|
||||
if (typeof result !== "object" || result === null) {
|
||||
return {
|
||||
isSaved: false,
|
||||
agentName: "",
|
||||
agentId: "",
|
||||
libraryAgentId: "",
|
||||
libraryAgentLink: "",
|
||||
};
|
||||
}
|
||||
const response = result as Record<string, unknown>;
|
||||
if (response.type === "agent_saved") {
|
||||
return {
|
||||
isSaved: true,
|
||||
agentName: (response.agent_name as string) || "Agent",
|
||||
agentId: (response.agent_id as string) || "",
|
||||
libraryAgentId: (response.library_agent_id as string) || "",
|
||||
libraryAgentLink: (response.library_agent_link as string) || "",
|
||||
};
|
||||
}
|
||||
return {
|
||||
isSaved: false,
|
||||
agentName: "",
|
||||
agentId: "",
|
||||
libraryAgentId: "",
|
||||
libraryAgentLink: "",
|
||||
};
|
||||
}
|
||||
|
||||
export function isErrorResponse(result: unknown): boolean {
|
||||
if (typeof result === "string") {
|
||||
const lower = result.toLowerCase();
|
||||
|
||||
@@ -41,17 +41,7 @@ export function HostScopedCredentialsModal({
|
||||
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
|
||||
|
||||
const formSchema = z.object({
|
||||
host: z
|
||||
.string()
|
||||
.min(1, "Host is required")
|
||||
.refine((val) => !/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(val), {
|
||||
message: "Enter only the host (e.g. api.example.com), not a full URL",
|
||||
})
|
||||
.refine((val) => !val.includes("/"), {
|
||||
message:
|
||||
"Enter only the host (e.g. api.example.com), without a trailing path. " +
|
||||
"You may specify a port (e.g. api.example.com:8080) if needed.",
|
||||
}),
|
||||
host: z.string().min(1, "Host is required"),
|
||||
title: z.string().optional(),
|
||||
headers: z.record(z.string()).optional(),
|
||||
});
|
||||
|
||||
@@ -62,6 +62,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
||||
| [Get Store Agent Details](block-integrations/system/store_operations.md#get-store-agent-details) | Get detailed information about an agent from the store |
|
||||
| [Get Weather Information](block-integrations/basic.md#get-weather-information) | Retrieves weather information for a specified location using OpenWeatherMap API |
|
||||
| [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution and wait for human approval or modification of data |
|
||||
| [Linear Search Issues](block-integrations/linear/issues.md#linear-search-issues) | Searches for issues on Linear |
|
||||
| [List Is Empty](block-integrations/basic.md#list-is-empty) | Checks if a list is empty |
|
||||
| [List Library Agents](block-integrations/system/library_operations.md#list-library-agents) | List all agents in your personal library |
|
||||
| [Note](block-integrations/basic.md#note) | A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes |
|
||||
@@ -192,7 +193,6 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
||||
| [Get Current Time](block-integrations/text.md#get-current-time) | This block outputs the current time |
|
||||
| [Match Text Pattern](block-integrations/text.md#match-text-pattern) | Matches text against a regex pattern and forwards data to positive or negative output based on the match |
|
||||
| [Text Decoder](block-integrations/text.md#text-decoder) | Decodes a string containing escape sequences into actual text |
|
||||
| [Text Encoder](block-integrations/text.md#text-encoder) | Encodes a string by converting special characters into escape sequences |
|
||||
| [Text Replace](block-integrations/text.md#text-replace) | This block is used to replace a text with a new text |
|
||||
| [Text Split](block-integrations/text.md#text-split) | This block is used to split a text into a list of strings |
|
||||
| [Word Character Count](block-integrations/text.md#word-character-count) | Counts the number of words and characters in a given text |
|
||||
@@ -571,7 +571,6 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
||||
| [Linear Create Comment](block-integrations/linear/comment.md#linear-create-comment) | Creates a new comment on a Linear issue |
|
||||
| [Linear Create Issue](block-integrations/linear/issues.md#linear-create-issue) | Creates a new issue on Linear |
|
||||
| [Linear Get Project Issues](block-integrations/linear/issues.md#linear-get-project-issues) | Gets issues from a Linear project filtered by status and assignee |
|
||||
| [Linear Search Issues](block-integrations/linear/issues.md#linear-search-issues) | Searches for issues on Linear |
|
||||
| [Linear Search Projects](block-integrations/linear/projects.md#linear-search-projects) | Searches for projects on Linear |
|
||||
|
||||
## Hardware
|
||||
|
||||
@@ -90,9 +90,9 @@ Searches for issues on Linear
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
This block searches for issues in Linear using a text query. It searches across issue titles, descriptions, and other fields to find matching issues. You can limit the number of results returned using the `max_results` parameter (default: 10, max: 100) to control token consumption and response size.
|
||||
This block searches for issues in Linear using a text query. It searches across issue titles, descriptions, and other fields to find matching issues.
|
||||
|
||||
Optionally filter results by team name to narrow searches to specific workspaces. If a team name is provided, the block resolves it to a team ID before searching. Returns matching issues with their state, creation date, project, and assignee information. If the search or team resolution fails, an error message is returned.
|
||||
Returns a list of issues matching the search term.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
@@ -100,14 +100,12 @@ Optionally filter results by team name to narrow searches to specific workspaces
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| term | Term to search for issues | str | Yes |
|
||||
| max_results | Maximum number of results to return | int | No |
|
||||
| team_name | Optional team name to filter results (e.g., 'Internal', 'Open Source') | str | No |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if the search failed | str |
|
||||
| error | Error message if the operation failed | str |
|
||||
| issues | List of issues | List[Issue] |
|
||||
|
||||
### Possible use case
|
||||
|
||||
@@ -380,42 +380,6 @@ This is useful when working with data from APIs or files where escape sequences
|
||||
|
||||
---
|
||||
|
||||
## Text Encoder
|
||||
|
||||
### What it is
|
||||
Encodes a string by converting special characters into escape sequences
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
The Text Encoder takes the input string and applies Python's `unicode_escape` encoding (equivalent to `codecs.encode(text, "unicode_escape").decode("utf-8")`) to transform special characters like newlines, tabs, and backslashes into their escaped forms.
|
||||
|
||||
The block relies on the input schema to ensure the value is a string; non-string inputs are rejected by validation, and any encoding failures surface as block errors. Non-ASCII characters are emitted as `\uXXXX` sequences, which is useful for ASCII-only payloads.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| text | A string containing special characters to be encoded | str | Yes |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if encoding fails | str |
|
||||
| encoded_text | The encoded text with special characters converted to escape sequences | str |
|
||||
|
||||
### Possible use case
|
||||
<!-- MANUAL: use_case -->
|
||||
**JSON Payload Preparation**: Encode multiline or quoted text before embedding it in JSON string fields to ensure proper escaping.
|
||||
|
||||
**Config/ENV Generation**: Convert template text into escaped strings for `.env` or YAML values that require special character handling.
|
||||
|
||||
**Snapshot Fixtures**: Produce stable escaped strings for golden files or API tests where consistent text representation is needed.
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
|
||||
## Text Replace
|
||||
|
||||
### What it is
|
||||
|
||||
Reference in New Issue
Block a user