Compare commits

8 commits: `testing-cl` ... `dev`

| Author | SHA1 | Date |
|---|---|---|
|  | 9a6e17ff52 |  |
|  | fb58827c61 |  |
|  | 595f3508c1 |  |
|  | 7892590b12 |  |
|  | 82d7134fc6 |  |
|  | 90466908a8 |  |
|  | f9f984a8f4 |  |
|  | fc87ed4e34 |  |
`.github/workflows/platform-frontend-ci.yml` (vendored, 38 changed lines)

```diff
@@ -128,7 +128,7 @@ jobs:
           token: ${{ secrets.GITHUB_TOKEN }}
           exitOnceUploaded: true
 
-  test:
+  e2e_test:
     runs-on: big-boi
     needs: setup
     strategy:
@@ -258,3 +258,39 @@ jobs:
       - name: Print Final Docker Compose logs
         if: always()
         run: docker compose -f ../docker-compose.yml logs
+
+  integration_test:
+    runs-on: ubuntu-latest
+    needs: setup
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          submodules: recursive
+
+      - name: Set up Node.js
+        uses: actions/setup-node@v4
+        with:
+          node-version: "22.18.0"
+
+      - name: Enable corepack
+        run: corepack enable
+
+      - name: Restore dependencies cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ needs.setup.outputs.cache-key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
+            ${{ runner.os }}-pnpm-
+
+      - name: Install dependencies
+        run: pnpm install --frozen-lockfile
+
+      - name: Generate API client
+        run: pnpm generate:api
+
+      - name: Run Integration Tests
+        run: pnpm test:unit
```
`__init__.py` (agent generator package)

```diff
@@ -1,29 +1,28 @@
 """Agent generator package - Creates agents from natural language."""
 
 from .core import (
-    apply_agent_patch,
+    AgentGeneratorNotConfiguredError,
     decompose_goal,
     generate_agent,
     generate_agent_patch,
     get_agent_as_json,
+    json_to_graph,
     save_agent_to_library,
 )
-from .fixer import apply_all_fixes
-from .utils import get_blocks_info
-from .validator import validate_agent
+from .service import health_check as check_external_service_health
+from .service import is_external_service_configured
 
 __all__ = [
     # Core functions
     "decompose_goal",
     "generate_agent",
     "generate_agent_patch",
-    "apply_agent_patch",
     "save_agent_to_library",
     "get_agent_as_json",
-    # Fixer
-    "apply_all_fixes",
-    # Validator
-    "validate_agent",
-    # Utils
-    "get_blocks_info",
+    "json_to_graph",
+    # Exceptions
+    "AgentGeneratorNotConfiguredError",
+    # Service
+    "is_external_service_configured",
+    "check_external_service_health",
 ]
```
`client.py` — file removed

```diff
@@ -1,25 +0,0 @@
-"""OpenRouter client configuration for agent generation."""
-
-import os
-
-from openai import AsyncOpenAI
-
-# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
-OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
-AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
-
-# OpenRouter client (OpenAI-compatible API)
-_client: AsyncOpenAI | None = None
-
-
-def get_client() -> AsyncOpenAI:
-    """Get or create the OpenRouter client."""
-    global _client
-    if _client is None:
-        if not OPENROUTER_API_KEY:
-            raise ValueError("OPENROUTER_API_KEY environment variable is required")
-        _client = AsyncOpenAI(
-            base_url="https://openrouter.ai/api/v1",
-            api_key=OPENROUTER_API_KEY,
-        )
-    return _client
```
`core.py`

```diff
@@ -1,7 +1,5 @@
 """Core agent generation functions."""
 
-import copy
-import json
 import logging
 import uuid
 from typing import Any
@@ -9,13 +7,35 @@ from typing import Any
 from backend.api.features.library import db as library_db
 from backend.data.graph import Graph, Link, Node, create_graph
 
-from .client import AGENT_GENERATOR_MODEL, get_client
-from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
-from .utils import get_block_summaries, parse_json_from_llm
+from .service import (
+    decompose_goal_external,
+    generate_agent_external,
+    generate_agent_patch_external,
+    is_external_service_configured,
+)
 
 logger = logging.getLogger(__name__)
 
 
+class AgentGeneratorNotConfiguredError(Exception):
+    """Raised when the external Agent Generator service is not configured."""
+
+    pass
+
+
+def _check_service_configured() -> None:
+    """Check if the external Agent Generator service is configured.
+
+    Raises:
+        AgentGeneratorNotConfiguredError: If the service is not configured.
+    """
+    if not is_external_service_configured():
+        raise AgentGeneratorNotConfiguredError(
+            "Agent Generator service is not configured. "
+            "Set AGENTGENERATOR_HOST environment variable to enable agent generation."
+        )
+
+
 async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
     """Break down a goal into steps or return clarifying questions.
 
@@ -28,40 +48,13 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
         - {"type": "clarifying_questions", "questions": [...]}
         - {"type": "instructions", "steps": [...]}
         Or None on error
+
+    Raises:
+        AgentGeneratorNotConfiguredError: If the external service is not configured.
     """
-    client = get_client()
-    prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
-
-    full_description = description
-    if context:
-        full_description = f"{description}\n\nAdditional context:\n{context}"
-
-    try:
-        response = await client.chat.completions.create(
-            model=AGENT_GENERATOR_MODEL,
-            messages=[
-                {"role": "system", "content": prompt},
-                {"role": "user", "content": full_description},
-            ],
-            temperature=0,
-        )
-
-        content = response.choices[0].message.content
-        if content is None:
-            logger.error("LLM returned empty content for decomposition")
-            return None
-
-        result = parse_json_from_llm(content)
-
-        if result is None:
-            logger.error(f"Failed to parse decomposition response: {content[:200]}")
-            return None
-
-        return result
-
-    except Exception as e:
-        logger.error(f"Error decomposing goal: {e}")
-        return None
+    _check_service_configured()
+    logger.info("Calling external Agent Generator service for decompose_goal")
+    return await decompose_goal_external(description, context)
 
 
 async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
@@ -72,31 +65,14 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
 
     Returns:
         Agent JSON dict or None on error
+
+    Raises:
+        AgentGeneratorNotConfiguredError: If the external service is not configured.
     """
-    client = get_client()
-    prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
-
-    try:
-        response = await client.chat.completions.create(
-            model=AGENT_GENERATOR_MODEL,
-            messages=[
-                {"role": "system", "content": prompt},
-                {"role": "user", "content": json.dumps(instructions, indent=2)},
-            ],
-            temperature=0,
-        )
-
-        content = response.choices[0].message.content
-        if content is None:
-            logger.error("LLM returned empty content for agent generation")
-            return None
-
-        result = parse_json_from_llm(content)
-
-        if result is None:
-            logger.error(f"Failed to parse agent JSON: {content[:200]}")
-            return None
-
+    _check_service_configured()
+    logger.info("Calling external Agent Generator service for generate_agent")
+    result = await generate_agent_external(instructions)
+    if result:
         # Ensure required fields
         if "id" not in result:
             result["id"] = str(uuid.uuid4())
@@ -104,12 +80,7 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
             result["version"] = 1
         if "is_active" not in result:
             result["is_active"] = True
-
-        return result
-
-    except Exception as e:
-        logger.error(f"Error generating agent: {e}")
-        return None
+    return result
 
 
 def json_to_graph(agent_json: dict[str, Any]) -> Graph:
@@ -284,108 +255,23 @@ async def get_agent_as_json(
 async def generate_agent_patch(
     update_request: str, current_agent: dict[str, Any]
 ) -> dict[str, Any] | None:
-    """Generate a patch to update an existing agent.
+    """Update an existing agent using natural language.
+
+    The external Agent Generator service handles:
+    - Generating the patch
+    - Applying the patch
+    - Fixing and validating the result
 
     Args:
         update_request: Natural language description of changes
        current_agent: Current agent JSON
 
     Returns:
-        Patch dict or clarifying questions, or None on error
+        Updated agent JSON, clarifying questions dict, or None on error
+
+    Raises:
+        AgentGeneratorNotConfiguredError: If the external service is not configured.
     """
-    client = get_client()
-    prompt = PATCH_PROMPT.format(
-        current_agent=json.dumps(current_agent, indent=2),
-        block_summaries=get_block_summaries(),
-    )
-
-    try:
-        response = await client.chat.completions.create(
-            model=AGENT_GENERATOR_MODEL,
-            messages=[
-                {"role": "system", "content": prompt},
-                {"role": "user", "content": update_request},
-            ],
-            temperature=0,
-        )
-
-        content = response.choices[0].message.content
-        if content is None:
-            logger.error("LLM returned empty content for patch generation")
-            return None
-
-        return parse_json_from_llm(content)
-
-    except Exception as e:
-        logger.error(f"Error generating patch: {e}")
-        return None
-
-
-def apply_agent_patch(
-    current_agent: dict[str, Any], patch: dict[str, Any]
-) -> dict[str, Any]:
-    """Apply a patch to an existing agent.
-
-    Args:
-        current_agent: Current agent JSON
-        patch: Patch dict with operations
-
-    Returns:
-        Updated agent JSON
-    """
-    agent = copy.deepcopy(current_agent)
-    patches = patch.get("patches", [])
-
-    for p in patches:
-        patch_type = p.get("type")
-
-        if patch_type == "modify":
-            node_id = p.get("node_id")
-            changes = p.get("changes", {})
-
-            for node in agent.get("nodes", []):
-                if node["id"] == node_id:
-                    _deep_update(node, changes)
-                    logger.debug(f"Modified node {node_id}")
-                    break
-
-        elif patch_type == "add":
-            new_nodes = p.get("new_nodes", [])
-            new_links = p.get("new_links", [])
-
-            agent["nodes"] = agent.get("nodes", []) + new_nodes
-            agent["links"] = agent.get("links", []) + new_links
-            logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
-
-        elif patch_type == "remove":
-            node_ids_to_remove = set(p.get("node_ids", []))
-            link_ids_to_remove = set(p.get("link_ids", []))
-
-            # Remove nodes
-            agent["nodes"] = [
-                n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
-            ]
-
-            # Remove links (both explicit and those referencing removed nodes)
-            agent["links"] = [
-                link
-                for link in agent.get("links", [])
-                if link["id"] not in link_ids_to_remove
-                and link["source_id"] not in node_ids_to_remove
-                and link["sink_id"] not in node_ids_to_remove
-            ]
-
-            logger.debug(
-                f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
-            )
-
-    return agent
-
-
-def _deep_update(target: dict, source: dict) -> None:
-    """Recursively update a dict with another dict."""
-    for key, value in source.items():
-        if key in target and isinstance(target[key], dict) and isinstance(value, dict):
-            _deep_update(target[key], value)
-        else:
-            target[key] = value
```
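After this change the public entry points are thin wrappers around the service client. A minimal usage sketch of the resulting flow — the package import path (`agent_generator`) is an assumption, and `AGENTGENERATOR_HOST` must be set for the calls to succeed:

```python
import asyncio

# Import path assumed; these names are re-exported by the package __init__.
from agent_generator import (
    AgentGeneratorNotConfiguredError,
    decompose_goal,
    generate_agent,
)


async def build_agent(goal: str) -> dict | None:
    try:
        plan = await decompose_goal(goal)
    except AgentGeneratorNotConfiguredError:
        # AGENTGENERATOR_HOST is not set; agent generation is disabled.
        return None
    if not plan or plan.get("type") != "instructions":
        # Either an error (None) or a dict of clarifying questions
        # that needs user answers before retrying.
        return None
    return await generate_agent(plan)


asyncio.run(build_agent("Email me a daily weather summary"))
```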
`fixer.py` — file removed

```diff
@@ -1,606 +0,0 @@
-"""Agent fixer - Fixes common LLM generation errors."""
-
-import logging
-import re
-import uuid
-from typing import Any
-
-from .utils import (
-    ADDTODICTIONARY_BLOCK_ID,
-    ADDTOLIST_BLOCK_ID,
-    CODE_EXECUTION_BLOCK_ID,
-    CONDITION_BLOCK_ID,
-    CREATEDICT_BLOCK_ID,
-    CREATELIST_BLOCK_ID,
-    DATA_SAMPLING_BLOCK_ID,
-    DOUBLE_CURLY_BRACES_BLOCK_IDS,
-    GET_CURRENT_DATE_BLOCK_ID,
-    STORE_VALUE_BLOCK_ID,
-    UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
-    get_blocks_info,
-    is_valid_uuid,
-)
-
-logger = logging.getLogger(__name__)
-
-
-def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
-    """Fix invalid UUIDs in agent and link IDs."""
-    # Fix agent ID
-    if not is_valid_uuid(agent.get("id", "")):
-        agent["id"] = str(uuid.uuid4())
-        logger.debug(f"Fixed agent ID: {agent['id']}")
-
-    # Fix node IDs
-    id_mapping = {}  # Old ID -> New ID
-    for node in agent.get("nodes", []):
-        if not is_valid_uuid(node.get("id", "")):
-            old_id = node.get("id", "")
-            new_id = str(uuid.uuid4())
-            id_mapping[old_id] = new_id
-            node["id"] = new_id
-            logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
-
-    # Fix link IDs and update references
-    for link in agent.get("links", []):
-        if not is_valid_uuid(link.get("id", "")):
-            link["id"] = str(uuid.uuid4())
-            logger.debug(f"Fixed link ID: {link['id']}")
-
-        # Update source/sink IDs if they were remapped
-        if link.get("source_id") in id_mapping:
-            link["source_id"] = id_mapping[link["source_id"]]
-        if link.get("sink_id") in id_mapping:
-            link["sink_id"] = id_mapping[link["sink_id"]]
-
-    return agent
-
-
-def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
-    """Fix single curly braces to double in template blocks."""
-    for node in agent.get("nodes", []):
-        if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
-            continue
-
-        input_data = node.get("input_default", {})
-        for key in ("prompt", "format"):
-            if key in input_data and isinstance(input_data[key], str):
-                original = input_data[key]
-                # Fix simple variable references: {var} -> {{var}}
-                fixed = re.sub(
-                    r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
-                    r"{{\1}}",
-                    original,
-                )
-                if fixed != original:
-                    input_data[key] = fixed
-                    logger.debug(f"Fixed curly braces in {key}")
-
-    return agent
-
-
-def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
-    """Add StoreValueBlock before ConditionBlock if needed for value2."""
-    nodes = agent.get("nodes", [])
-    links = agent.get("links", [])
-
-    # Find all ConditionBlock nodes
-    condition_node_ids = {
-        node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
-    }
-
-    if not condition_node_ids:
-        return agent
-
-    new_nodes = []
-    new_links = []
-    processed_conditions = set()
-
-    for link in links:
-        sink_id = link.get("sink_id")
-        sink_name = link.get("sink_name")
-
-        # Check if this link goes to a ConditionBlock's value2
-        if sink_id in condition_node_ids and sink_name == "value2":
-            source_node = next(
-                (n for n in nodes if n["id"] == link.get("source_id")), None
-            )
-
-            # Skip if source is already a StoreValueBlock
-            if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
-                continue
-
-            # Skip if we already processed this condition
-            if sink_id in processed_conditions:
-                continue
-
-            processed_conditions.add(sink_id)
-
-            # Create StoreValueBlock
-            store_node_id = str(uuid.uuid4())
-            store_node = {
-                "id": store_node_id,
-                "block_id": STORE_VALUE_BLOCK_ID,
-                "input_default": {"data": None},
-                "metadata": {"position": {"x": 0, "y": -100}},
-            }
-            new_nodes.append(store_node)
-
-            # Create link: original source -> StoreValueBlock
-            new_links.append(
-                {
-                    "id": str(uuid.uuid4()),
-                    "source_id": link["source_id"],
-                    "source_name": link["source_name"],
-                    "sink_id": store_node_id,
-                    "sink_name": "input",
-                    "is_static": False,
-                }
-            )
-
-            # Update original link: StoreValueBlock -> ConditionBlock
-            link["source_id"] = store_node_id
-            link["source_name"] = "output"
-
-            logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
-
-    if new_nodes:
-        agent["nodes"] = nodes + new_nodes
-
-    return agent
-
-
-def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
-    """Fix AddToList blocks by adding prerequisite empty AddToList block.
-
-    When an AddToList block is found:
-    1. Checks if there's a CreateListBlock before it
-    2. Removes CreateListBlock if linked directly to AddToList
-    3. Adds an empty AddToList block before the original
-    4. Ensures the original has a self-referencing link
-    """
-    nodes = agent.get("nodes", [])
-    links = agent.get("links", [])
-    new_nodes = []
-    original_addtolist_ids = set()
-    nodes_to_remove = set()
-    links_to_remove = []
-
-    # First pass: identify CreateListBlock nodes to remove
-    for link in links:
-        source_node = next(
-            (n for n in nodes if n.get("id") == link.get("source_id")), None
-        )
-        sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
-
-        if (
-            source_node
-            and sink_node
-            and source_node.get("block_id") == CREATELIST_BLOCK_ID
-            and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
-        ):
-            nodes_to_remove.add(source_node.get("id"))
-            links_to_remove.append(link)
-            logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
-
-    # Second pass: process AddToList blocks
-    filtered_nodes = []
-    for node in nodes:
-        if node.get("id") in nodes_to_remove:
-            continue
-
-        if node.get("block_id") == ADDTOLIST_BLOCK_ID:
-            original_addtolist_ids.add(node.get("id"))
-            node_id = node.get("id")
-            pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
-
-            # Check if already has prerequisite
-            has_prereq = any(
-                link.get("sink_id") == node_id
-                and link.get("sink_name") == "list"
-                and link.get("source_name") == "updated_list"
-                for link in links
-            )
-
-            if not has_prereq:
-                # Remove links to "list" input (except self-reference)
-                for link in links:
-                    if (
-                        link.get("sink_id") == node_id
-                        and link.get("sink_name") == "list"
-                        and link.get("source_id") != node_id
-                        and link not in links_to_remove
-                    ):
-                        links_to_remove.append(link)
-
-                # Create prerequisite AddToList block
-                prereq_id = str(uuid.uuid4())
-                prereq_node = {
-                    "id": prereq_id,
-                    "block_id": ADDTOLIST_BLOCK_ID,
-                    "input_default": {"list": [], "entry": None, "entries": []},
-                    "metadata": {
-                        "position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
-                    },
-                }
-                new_nodes.append(prereq_node)
-
-                # Link prerequisite to original
-                links.append(
-                    {
-                        "id": str(uuid.uuid4()),
-                        "source_id": prereq_id,
-                        "source_name": "updated_list",
-                        "sink_id": node_id,
-                        "sink_name": "list",
-                        "is_static": False,
-                    }
-                )
-                logger.debug(f"Added prerequisite AddToList block for {node_id}")
-
-        filtered_nodes.append(node)
-
-    # Remove marked links
-    filtered_links = [link for link in links if link not in links_to_remove]
-
-    # Add self-referencing links for original AddToList blocks
-    for node in filtered_nodes + new_nodes:
-        if (
-            node.get("block_id") == ADDTOLIST_BLOCK_ID
-            and node.get("id") in original_addtolist_ids
-        ):
-            node_id = node.get("id")
-            has_self_ref = any(
-                link["source_id"] == node_id
-                and link["sink_id"] == node_id
-                and link["source_name"] == "updated_list"
-                and link["sink_name"] == "list"
-                for link in filtered_links
-            )
-            if not has_self_ref:
-                filtered_links.append(
-                    {
-                        "id": str(uuid.uuid4()),
-                        "source_id": node_id,
-                        "source_name": "updated_list",
-                        "sink_id": node_id,
-                        "sink_name": "list",
-                        "is_static": False,
-                    }
-                )
-                logger.debug(f"Added self-reference for AddToList {node_id}")
-
-    agent["nodes"] = filtered_nodes + new_nodes
-    agent["links"] = filtered_links
-    return agent
-
-
-def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
-    """Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
-    nodes = agent.get("nodes", [])
-    links = agent.get("links", [])
-    nodes_to_remove = set()
-    links_to_remove = []
-
-    for link in links:
-        source_node = next(
-            (n for n in nodes if n.get("id") == link.get("source_id")), None
-        )
-        sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
-
-        if (
-            source_node
-            and sink_node
-            and source_node.get("block_id") == CREATEDICT_BLOCK_ID
-            and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
-        ):
-            nodes_to_remove.add(source_node.get("id"))
-            links_to_remove.append(link)
-            logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
-
-    agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
-    agent["links"] = [link for link in links if link not in links_to_remove]
-    return agent
-
-
-def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
-    """Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
-    nodes = agent.get("nodes", [])
-    links = agent.get("links", [])
-
-    for link in links:
-        source_node = next(
-            (n for n in nodes if n.get("id") == link.get("source_id")), None
-        )
-        if (
-            source_node
-            and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
-            and link.get("source_name") == "response"
-        ):
-            link["source_name"] = "stdout_logs"
-            logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
-
-    return agent
-
-
-def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
-    """Fix DataSamplingBlock by setting sample_size to 1 as default."""
-    nodes = agent.get("nodes", [])
-    links = agent.get("links", [])
-    links_to_remove = []
-
-    for node in nodes:
-        if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
-            node_id = node.get("id")
-            input_default = node.get("input_default", {})
-
-            # Remove links to sample_size
-            for link in links:
-                if (
-                    link.get("sink_id") == node_id
-                    and link.get("sink_name") == "sample_size"
-                ):
-                    links_to_remove.append(link)
-
-            # Set default
-            input_default["sample_size"] = 1
-            node["input_default"] = input_default
-            logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
-
-    if links_to_remove:
-        agent["links"] = [link for link in links if link not in links_to_remove]
-
-    return agent
-
-
-def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
-    """Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
-    nodes = agent.get("nodes", [])
-    links = agent.get("links", [])
-    node_lookup = {n.get("id"): n for n in nodes}
-
-    for link in links:
-        source_id = link.get("source_id")
-        sink_id = link.get("sink_id")
-
-        source_node = node_lookup.get(source_id)
-        sink_node = node_lookup.get(sink_id)
-
-        if not source_node or not sink_node:
-            continue
-
-        source_pos = source_node.get("metadata", {}).get("position", {})
-        sink_pos = sink_node.get("metadata", {}).get("position", {})
-
-        source_x = source_pos.get("x", 0)
-        sink_x = sink_pos.get("x", 0)
-
-        if abs(sink_x - source_x) < 800:
-            new_x = source_x + 800
-            if "metadata" not in sink_node:
-                sink_node["metadata"] = {}
-            if "position" not in sink_node["metadata"]:
-                sink_node["metadata"]["position"] = {}
-            sink_node["metadata"]["position"]["x"] = new_x
-            logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
-
-    return agent
-
-
-def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
-    """Fix GetCurrentDateBlock offset to ensure it's positive."""
-    for node in agent.get("nodes", []):
-        if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
-            input_default = node.get("input_default", {})
-            if "offset" in input_default:
-                offset = input_default["offset"]
-                if isinstance(offset, (int, float)) and offset < 0:
-                    input_default["offset"] = abs(offset)
-                    logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
-
-    return agent
-
-
-def fix_ai_model_parameter(
-    agent: dict[str, Any],
-    blocks_info: list[dict[str, Any]],
-    default_model: str = "gpt-4o",
-) -> dict[str, Any]:
-    """Add default model parameter to AI blocks if missing."""
-    block_map = {b.get("id"): b for b in blocks_info}
-
-    for node in agent.get("nodes", []):
-        block_id = node.get("block_id")
-        block = block_map.get(block_id)
-
-        if not block:
-            continue
-
-        # Check if block has AI category
-        categories = block.get("categories", [])
-        is_ai_block = any(
-            cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
-        )
-
-        if is_ai_block:
-            input_default = node.get("input_default", {})
-            if "model" not in input_default:
-                input_default["model"] = default_model
-                node["input_default"] = input_default
-                logger.debug(
-                    f"Added model '{default_model}' to AI block {node.get('id')}"
-                )
-
-    return agent
-
-
-def fix_link_static_properties(
-    agent: dict[str, Any], blocks_info: list[dict[str, Any]]
-) -> dict[str, Any]:
-    """Fix is_static property based on source block's staticOutput."""
-    block_map = {b.get("id"): b for b in blocks_info}
-    node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
-
-    for link in agent.get("links", []):
-        source_node = node_lookup.get(link.get("source_id"))
-        if not source_node:
-            continue
-
-        source_block = block_map.get(source_node.get("block_id"))
-        if not source_block:
-            continue
-
-        static_output = source_block.get("staticOutput", False)
-        if link.get("is_static") != static_output:
-            link["is_static"] = static_output
-            logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
-
-    return agent
-
-
-def fix_data_type_mismatch(
-    agent: dict[str, Any], blocks_info: list[dict[str, Any]]
-) -> dict[str, Any]:
-    """Fix data type mismatches by inserting UniversalTypeConverterBlock."""
-    nodes = agent.get("nodes", [])
-    links = agent.get("links", [])
-    block_map = {b.get("id"): b for b in blocks_info}
-    node_lookup = {n.get("id"): n for n in nodes}
-
-    def get_property_type(schema: dict, name: str) -> str | None:
-        if "_#_" in name:
-            parent, child = name.split("_#_", 1)
-            parent_schema = schema.get(parent, {})
-            if "properties" in parent_schema:
-                return parent_schema["properties"].get(child, {}).get("type")
-            return None
-        return schema.get(name, {}).get("type")
-
-    def are_types_compatible(src: str, sink: str) -> bool:
-        if {src, sink} <= {"integer", "number"}:
-            return True
-        return src == sink
-
-    type_mapping = {
-        "string": "string",
-        "text": "string",
-        "integer": "number",
-        "number": "number",
-        "float": "number",
-        "boolean": "boolean",
-        "bool": "boolean",
-        "array": "list",
-        "list": "list",
-        "object": "dictionary",
-        "dict": "dictionary",
-        "dictionary": "dictionary",
-    }
-
-    new_links = []
-    nodes_to_add = []
-
-    for link in links:
-        source_node = node_lookup.get(link.get("source_id"))
-        sink_node = node_lookup.get(link.get("sink_id"))
-
-        if not source_node or not sink_node:
-            new_links.append(link)
-            continue
-
-        source_block = block_map.get(source_node.get("block_id"))
-        sink_block = block_map.get(sink_node.get("block_id"))
-
-        if not source_block or not sink_block:
-            new_links.append(link)
-            continue
-
-        source_outputs = source_block.get("outputSchema", {}).get("properties", {})
-        sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
-
-        source_type = get_property_type(source_outputs, link.get("source_name", ""))
-        sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
-
-        if (
-            source_type
-            and sink_type
-            and not are_types_compatible(source_type, sink_type)
-        ):
-            # Insert type converter
-            converter_id = str(uuid.uuid4())
-            target_type = type_mapping.get(sink_type, sink_type)
-
-            converter_node = {
-                "id": converter_id,
-                "block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
-                "input_default": {"type": target_type},
-                "metadata": {"position": {"x": 0, "y": 100}},
-            }
-            nodes_to_add.append(converter_node)
-
-            # source -> converter
-            new_links.append(
-                {
-                    "id": str(uuid.uuid4()),
-                    "source_id": link["source_id"],
-                    "source_name": link["source_name"],
-                    "sink_id": converter_id,
-                    "sink_name": "value",
-                    "is_static": False,
-                }
-            )
-
-            # converter -> sink
-            new_links.append(
-                {
-                    "id": str(uuid.uuid4()),
-                    "source_id": converter_id,
-                    "source_name": "value",
-                    "sink_id": link["sink_id"],
-                    "sink_name": link["sink_name"],
-                    "is_static": False,
-                }
-            )
-
-            logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
-        else:
-            new_links.append(link)
-
-    if nodes_to_add:
-        agent["nodes"] = nodes + nodes_to_add
-        agent["links"] = new_links
-
-    return agent
-
-
-def apply_all_fixes(
-    agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
-) -> dict[str, Any]:
-    """Apply all fixes to an agent JSON.
-
-    Args:
-        agent: Agent JSON dict
-        blocks_info: Optional list of block info dicts for advanced fixes
-
-    Returns:
-        Fixed agent JSON
-    """
-    # Basic fixes (no block info needed)
-    agent = fix_agent_ids(agent)
-    agent = fix_double_curly_braces(agent)
-    agent = fix_storevalue_before_condition(agent)
-    agent = fix_addtolist_blocks(agent)
-    agent = fix_addtodictionary_blocks(agent)
-    agent = fix_code_execution_output(agent)
-    agent = fix_data_sampling_sample_size(agent)
-    agent = fix_node_x_coordinates(agent)
-    agent = fix_getcurrentdate_offset(agent)
-
-    # Advanced fixes (require block info)
-    if blocks_info is None:
-        blocks_info = get_blocks_info()
-
-    agent = fix_ai_model_parameter(agent, blocks_info)
-    agent = fix_link_static_properties(agent, blocks_info)
-    agent = fix_data_type_mismatch(agent, blocks_info)
-
-    return agent
```
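The heart of the removed `fix_double_curly_braces` is a single lookaround regex. A standalone sketch of its effect, with the pattern taken verbatim from the module above:

```python
import re

# Rule from fix_double_curly_braces: wrap bare {var} references as {{var}},
# while already-doubled {{var}} is left untouched by the lookarounds.
PATTERN = r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})"

text = "Hello {name}, report for {{date}}"
print(re.sub(PATTERN, r"{{\1}}", text))
# -> Hello {{name}}, report for {{date}}
```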
`prompts.py` — file removed

````diff
@@ -1,225 +0,0 @@
-"""Prompt templates for agent generation."""
-
-DECOMPOSITION_PROMPT = """
-You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
-
-Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
-
----
-
-FIRST: Analyze the user's goal and determine:
-1) Design-time configuration (fixed settings that won't change per run)
-2) Runtime inputs (values the agent's end-user will provide each time it runs)
-
-For anything that can vary per run (email addresses, names, dates, search terms, etc.):
-- DO NOT ask for the actual value
-- Instead, define it as an Agent Input with a clear name, type, and description
-
-Only ask clarifying questions about design-time config that affects how you build the workflow:
-- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
-- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
-- Business rules that must be hard-coded
-
-IMPORTANT CLARIFICATIONS POLICY:
-- Ask no more than five essential questions
-- Do not ask for concrete values that can be provided at runtime as Agent Inputs
-- Do not ask for API keys or credentials; the platform handles those directly
-- If there is enough information to infer reasonable defaults, prefer to propose defaults
-
----
-
-GUIDELINES:
-1. List each step as a numbered item
-2. Describe the action clearly and specify inputs/outputs
-3. Ensure steps are in logical, sequential order
-4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
-5. Help the user reach their goal efficiently
-
----
-
-RULES:
-1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
-2. USE ONLY THE BLOCKS PROVIDED
-3. ALL required_input fields must be provided
-4. Data types of linked properties must match
-5. Write expert-level prompts for AI-related blocks
-
----
-
-CRITICAL BLOCK RESTRICTIONS:
-1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
-2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
-3. ConditionBlock: value2 is reference, value1 is contrast
-4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
-5. ReadCsvBlock: Only use the 'rows' output, not 'row'
-
----
-
-OUTPUT FORMAT:
-
-If more information is needed:
-```json
-{{
-    "type": "clarifying_questions",
-    "questions": [
-        {{
-            "question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
-            "keyword": "email_provider",
-            "example": "Gmail"
-        }}
-    ]
-}}
-```
-
-If ready to proceed:
-```json
-{{
-    "type": "instructions",
-    "steps": [
-        {{
-            "step_number": 1,
-            "block_name": "AgentShortTextInputBlock",
-            "description": "Get the URL of the content to analyze.",
-            "inputs": [{{"name": "name", "value": "URL"}}],
-            "outputs": [{{"name": "result", "description": "The URL entered by user"}}]
-        }}
-    ]
-}}
-```
-
----
-
-AVAILABLE BLOCKS:
-{block_summaries}
-"""
-
-GENERATION_PROMPT = """
-You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
-
----
-
-NODES:
-Each node must include:
-- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
-- `block_id`: The block identifier (must match an Allowed Block)
-- `input_default`: Dict of inputs (can be empty if no static inputs needed)
-- `metadata`: Must contain:
-  - `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
-  - `customized_name`: Clear name describing this block's purpose in the workflow
-
----
-
-LINKS:
-Each link connects a source node's output to a sink node's input:
-- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
-- `source_id`: ID of the source node
-- `source_name`: Output field name from the source block
-- `sink_id`: ID of the sink node
-- `sink_name`: Input field name on the sink block
-- `is_static`: true only if source block has static_output: true
-
-CRITICAL: All IDs must be valid UUID v4 format!
-
----
-
-AGENT (GRAPH):
-Wrap nodes and links in:
-- `id`: UUID of the agent
-- `name`: Short, generic name (avoid specific company names, URLs)
-- `description`: Short, generic description
-- `nodes`: List of all nodes
-- `links`: List of all links
-- `version`: 1
-- `is_active`: true
-
----
-
-TIPS:
-- All required_input fields must be provided via input_default or a valid link
-- Ensure consistent source_id and sink_id references
-- Avoid dangling links
-- Input/output pins must match block schemas
-- Do not invent unknown block_ids
-
----
-
-ALLOWED BLOCKS:
-{block_summaries}
-
----
-
-Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
-"""
-
-PATCH_PROMPT = """
-You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
-
-CURRENT AGENT:
-{current_agent}
-
-AVAILABLE BLOCKS:
-{block_summaries}
-
----
-
-PATCH FORMAT:
-Return a JSON object with the following structure:
-
-```json
-{{
-    "type": "patch",
-    "intent": "Brief description of what the patch does",
-    "patches": [
-        {{
-            "type": "modify",
-            "node_id": "uuid-of-node-to-modify",
-            "changes": {{
-                "input_default": {{"field": "new_value"}},
-                "metadata": {{"customized_name": "New Name"}}
-            }}
-        }},
-        {{
-            "type": "add",
-            "new_nodes": [
-                {{
-                    "id": "new-uuid",
-                    "block_id": "block-uuid",
-                    "input_default": {{}},
-                    "metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
-                }}
-            ],
-            "new_links": [
-                {{
-                    "id": "link-uuid",
-                    "source_id": "source-node-id",
-                    "source_name": "output_field",
-                    "sink_id": "sink-node-id",
-                    "sink_name": "input_field"
-                }}
-            ]
-        }},
-        {{
-            "type": "remove",
-            "node_ids": ["uuid-of-node-to-remove"],
-            "link_ids": ["uuid-of-link-to-remove"]
-        }}
-    ]
-}}
-```
-
-If you need more information, return:
-```json
-{{
-    "type": "clarifying_questions",
-    "questions": [
-        {{
-            "question": "What specific change do you want?",
-            "keyword": "change_type",
-            "example": "Add error handling"
-        }}
-    ]
-}}
-```
-
-Generate the minimal patch needed. Output ONLY valid JSON.
-"""
````
`service.py` — new file

```diff
@@ -0,0 +1,269 @@
+"""External Agent Generator service client.
+
+This module provides a client for communicating with the external Agent Generator
+microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
+will delegate to the external service instead of using the built-in LLM-based implementation.
+"""
+
+import logging
+from typing import Any
+
+import httpx
+
+from backend.util.settings import Settings
+
+logger = logging.getLogger(__name__)
+
+_client: httpx.AsyncClient | None = None
+_settings: Settings | None = None
+
+
+def _get_settings() -> Settings:
+    """Get or create settings singleton."""
+    global _settings
+    if _settings is None:
+        _settings = Settings()
+    return _settings
+
+
+def is_external_service_configured() -> bool:
+    """Check if external Agent Generator service is configured."""
+    settings = _get_settings()
+    return bool(settings.config.agentgenerator_host)
+
+
+def _get_base_url() -> str:
+    """Get the base URL for the external service."""
+    settings = _get_settings()
+    host = settings.config.agentgenerator_host
+    port = settings.config.agentgenerator_port
+    return f"http://{host}:{port}"
+
+
+def _get_client() -> httpx.AsyncClient:
+    """Get or create the HTTP client for the external service."""
+    global _client
+    if _client is None:
+        settings = _get_settings()
+        _client = httpx.AsyncClient(
+            base_url=_get_base_url(),
+            timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
+        )
+    return _client
+
+
+async def decompose_goal_external(
+    description: str, context: str = ""
+) -> dict[str, Any] | None:
+    """Call the external service to decompose a goal.
+
+    Args:
+        description: Natural language goal description
+        context: Additional context (e.g., answers to previous questions)
+
+    Returns:
+        Dict with either:
+        - {"type": "clarifying_questions", "questions": [...]}
+        - {"type": "instructions", "steps": [...]}
+        - {"type": "unachievable_goal", ...}
+        - {"type": "vague_goal", ...}
+        Or None on error
+    """
+    client = _get_client()
+
+    # Build the request payload
+    payload: dict[str, Any] = {"description": description}
+    if context:
+        # The external service uses user_instruction for additional context
+        payload["user_instruction"] = context
+
+    try:
+        response = await client.post("/api/decompose-description", json=payload)
+        response.raise_for_status()
+        data = response.json()
+
+        if not data.get("success"):
+            logger.error(f"External service returned error: {data.get('error')}")
+            return None
+
+        # Map the response to the expected format
+        response_type = data.get("type")
+        if response_type == "instructions":
+            return {"type": "instructions", "steps": data.get("steps", [])}
+        elif response_type == "clarifying_questions":
+            return {
+                "type": "clarifying_questions",
+                "questions": data.get("questions", []),
+            }
+        elif response_type == "unachievable_goal":
+            return {
+                "type": "unachievable_goal",
+                "reason": data.get("reason"),
+                "suggested_goal": data.get("suggested_goal"),
+            }
+        elif response_type == "vague_goal":
+            return {
+                "type": "vague_goal",
+                "suggested_goal": data.get("suggested_goal"),
+            }
+        else:
+            logger.error(
+                f"Unknown response type from external service: {response_type}"
+            )
+            return None
+
+    except httpx.HTTPStatusError as e:
+        logger.error(f"HTTP error calling external agent generator: {e}")
+        return None
+    except httpx.RequestError as e:
+        logger.error(f"Request error calling external agent generator: {e}")
+        return None
+    except Exception as e:
+        logger.error(f"Unexpected error calling external agent generator: {e}")
+        return None
+
+
+async def generate_agent_external(
+    instructions: dict[str, Any]
+) -> dict[str, Any] | None:
+    """Call the external service to generate an agent from instructions.
+
+    Args:
+        instructions: Structured instructions from decompose_goal
+
+    Returns:
+        Agent JSON dict or None on error
+    """
+    client = _get_client()
+
+    try:
+        response = await client.post(
+            "/api/generate-agent", json={"instructions": instructions}
+        )
+        response.raise_for_status()
+        data = response.json()
+
+        if not data.get("success"):
+            logger.error(f"External service returned error: {data.get('error')}")
+            return None
+
+        return data.get("agent_json")
+
+    except httpx.HTTPStatusError as e:
+        logger.error(f"HTTP error calling external agent generator: {e}")
+        return None
+    except httpx.RequestError as e:
+        logger.error(f"Request error calling external agent generator: {e}")
+        return None
+    except Exception as e:
+        logger.error(f"Unexpected error calling external agent generator: {e}")
+        return None
+
+
+async def generate_agent_patch_external(
+    update_request: str, current_agent: dict[str, Any]
+) -> dict[str, Any] | None:
+    """Call the external service to generate a patch for an existing agent.
+
+    Args:
+        update_request: Natural language description of changes
+        current_agent: Current agent JSON
+
+    Returns:
+        Updated agent JSON, clarifying questions dict, or None on error
+    """
+    client = _get_client()
+
+    try:
+        response = await client.post(
+            "/api/update-agent",
+            json={
+                "update_request": update_request,
+                "current_agent_json": current_agent,
+            },
+        )
+        response.raise_for_status()
+        data = response.json()
+
+        if not data.get("success"):
+            logger.error(f"External service returned error: {data.get('error')}")
+            return None
+
+        # Check if it's clarifying questions
+        if data.get("type") == "clarifying_questions":
+            return {
+                "type": "clarifying_questions",
+                "questions": data.get("questions", []),
+            }
+
+        # Otherwise return the updated agent JSON
+        return data.get("agent_json")
+
+    except httpx.HTTPStatusError as e:
+        logger.error(f"HTTP error calling external agent generator: {e}")
+        return None
+    except httpx.RequestError as e:
+        logger.error(f"Request error calling external agent generator: {e}")
+        return None
+    except Exception as e:
+        logger.error(f"Unexpected error calling external agent generator: {e}")
+        return None
+
+
+async def get_blocks_external() -> list[dict[str, Any]] | None:
+    """Get available blocks from the external service.
+
+    Returns:
+        List of block info dicts or None on error
+    """
+    client = _get_client()
+
+    try:
+        response = await client.get("/api/blocks")
+        response.raise_for_status()
+        data = response.json()
+
+        if not data.get("success"):
+            logger.error("External service returned error getting blocks")
+            return None
+
+        return data.get("blocks", [])
+
+    except httpx.HTTPStatusError as e:
+        logger.error(f"HTTP error getting blocks from external service: {e}")
+        return None
+    except httpx.RequestError as e:
+        logger.error(f"Request error getting blocks from external service: {e}")
+        return None
+    except Exception as e:
+        logger.error(f"Unexpected error getting blocks from external service: {e}")
+        return None
+
+
+async def health_check() -> bool:
+    """Check if the external service is healthy.
+
+    Returns:
+        True if healthy, False otherwise
+    """
+    if not is_external_service_configured():
+        return False
+
+    client = _get_client()
+
+    try:
+        response = await client.get("/health")
+        response.raise_for_status()
+        data = response.json()
+        return data.get("status") == "healthy" and data.get("blocks_loaded", False)
+    except Exception as e:
+        logger.warning(f"External agent generator health check failed: {e}")
+        return False
+
+
+async def close_client() -> None:
+    """Close the HTTP client."""
+    global _client
+    if _client is not None:
+        await _client.aclose()
+        _client = None
```
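A minimal sketch of driving this client end to end — the import path (`agent_generator.service`) is an assumption, as is a service reachable at the configured host and port:

```python
import asyncio

from agent_generator import service  # import path assumed


async def main() -> None:
    # Mirror the gating done in core.py: configuration first, then /health.
    if not service.is_external_service_configured():
        return
    if not await service.health_check():
        return

    result = await service.decompose_goal_external("Email me a daily news digest")
    # result is one of: instructions, clarifying_questions, unachievable_goal,
    # vague_goal, or None on error.
    print(result)

    await service.close_client()


asyncio.run(main())
```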
@@ -1,213 +0,0 @@
"""Utilities for agent generation."""

import json
import re
from typing import Any

from backend.data.block import get_blocks

# UUID validation regex
UUID_REGEX = re.compile(
    r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
)

# Block IDs for various fixes
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"

DOUBLE_CURLY_BRACES_BLOCK_IDS = [
    "44f6c8ad-d75c-4ae1-8209-aad1c0326928",  # FillTextTemplateBlock
    "6ab085e2-20b3-4055-bc3e-08036e01eca6",
    "90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
    "363ae599-353e-4804-937e-b2ee3cef3da4",  # AgentOutputBlock
    "3b191d9f-356f-482d-8238-ba04b6d18381",
    "db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
    "3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
    "ed1ae7a0-b770-4089-b520-1f0005fad19a",
    "a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
    "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
    "716a67b3-6760-42e7-86dc-18645c6e00fc",
    "530cf046-2ce0-4854-ae2c-659db17c7a46",
    "ed55ac19-356e-4243-a6cb-bc599e9b716f",
    "1f292d4a-41a4-4977-9684-7c8d560b9f91",  # LLM blocks
    "32a87eab-381e-4dd4-bdb8-4c47151be35a",
]


def is_valid_uuid(value: str) -> bool:
    """Check if a string is a valid UUID v4."""
    return isinstance(value, str) and UUID_REGEX.match(value) is not None


def _compact_schema(schema: dict) -> dict[str, str]:
    """Extract compact type info from a JSON schema properties dict.

    Returns a dict of {field_name: type_string} for essential info only.
    """
    props = schema.get("properties", {})
    result = {}

    for name, prop in props.items():
        # Skip internal/complex fields
        if name.startswith("_"):
            continue

        # Get type string
        type_str = prop.get("type", "any")

        # Handle anyOf/oneOf (optional types)
        if "anyOf" in prop:
            types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
            type_str = "|".join(types) if types else "any"
        elif "allOf" in prop:
            type_str = "object"

        # Add array item type if present
        if type_str == "array" and "items" in prop:
            items = prop["items"]
            if isinstance(items, dict):
                item_type = items.get("type", "any")
                type_str = f"array[{item_type}]"

        result[name] = type_str

    return result


def get_block_summaries(include_schemas: bool = True) -> str:
    """Generate compact block summaries for prompts.

    Args:
        include_schemas: Whether to include input/output type info

    Returns:
        Formatted string of block summaries (compact format)
    """
    blocks = get_blocks()
    summaries = []

    for block_id, block_cls in blocks.items():
        block = block_cls()
        name = block.name
        desc = getattr(block, "description", "") or ""

        # Truncate description
        if len(desc) > 150:
            desc = desc[:147] + "..."

        if not include_schemas:
            summaries.append(f"- {name} (id: {block_id}): {desc}")
        else:
            # Compact format with type info only
            inputs = {}
            outputs = {}
            required = []

            if hasattr(block, "input_schema"):
                try:
                    schema = block.input_schema.jsonschema()
                    inputs = _compact_schema(schema)
                    required = schema.get("required", [])
                except Exception:
                    pass

            if hasattr(block, "output_schema"):
                try:
                    schema = block.output_schema.jsonschema()
                    outputs = _compact_schema(schema)
                except Exception:
                    pass

            # Build compact line format
            # Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
            in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
            out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
            req_str = f" req=[{','.join(required)}]" if required else ""

            static = " [static]" if getattr(block, "static_output", False) else ""

            line = f"- {name} (id: {block_id}): {desc}"
            if in_str:
                line += f"\n in: {{{in_str}}}{req_str}"
            if out_str:
                line += f"\n out: {{{out_str}}}{static}"

            summaries.append(line)

    return "\n".join(summaries)


def get_blocks_info() -> list[dict[str, Any]]:
    """Get block information with schemas for validation and fixing."""
    blocks = get_blocks()
    blocks_info = []
    for block_id, block_cls in blocks.items():
        block = block_cls()
        blocks_info.append(
            {
                "id": block_id,
                "name": block.name,
                "description": getattr(block, "description", ""),
                "categories": getattr(block, "categories", []),
                "staticOutput": getattr(block, "static_output", False),
                "inputSchema": (
                    block.input_schema.jsonschema()
                    if hasattr(block, "input_schema")
                    else {}
                ),
                "outputSchema": (
                    block.output_schema.jsonschema()
                    if hasattr(block, "output_schema")
                    else {}
                ),
            }
        )
    return blocks_info


def parse_json_from_llm(text: str) -> dict[str, Any] | None:
    """Extract JSON from LLM response (handles markdown code blocks)."""
    if not text:
        return None

    # Try fenced code block
    match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
    if match:
        try:
            return json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            pass

    # Try raw text
    try:
        return json.loads(text.strip())
    except json.JSONDecodeError:
        pass

    # Try finding {...} span
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            pass

    # Try finding [...] span
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end > start:
        try:
            return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            pass

    return None
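For reviewers unfamiliar with the removed helper: parse_json_from_llm walked four fallbacks in order — fenced code block, raw text, first {...} span, then first [...] span. Doctest-style illustration of its behavior against the deleted implementation:

>>> parse_json_from_llm('```json\n{"a": 1}\n```')   # fenced block
{'a': 1}
>>> parse_json_from_llm('{"a": 1}')                  # raw JSON
{'a': 1}
>>> parse_json_from_llm('Answer: {"a": 1}. Done.')   # embedded {...} span
{'a': 1}
>>> parse_json_from_llm("no json here") is None      # nothing parseable
True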
@@ -1,279 +0,0 @@
"""Agent validator - Validates agent structure and connections."""

import logging
import re
from typing import Any

from .utils import get_blocks_info

logger = logging.getLogger(__name__)


class AgentValidator:
    """Validator for AutoGPT agents with detailed error reporting."""

    def __init__(self):
        self.errors: list[str] = []

    def add_error(self, error: str) -> None:
        """Add an error message."""
        self.errors.append(error)

    def validate_block_existence(
        self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
    ) -> bool:
        """Validate all block IDs exist in the blocks library."""
        valid = True
        valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}

        for node in agent.get("nodes", []):
            block_id = node.get("block_id")
            node_id = node.get("id")

            if not block_id:
                self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
                valid = False
                continue

            if block_id not in valid_block_ids:
                self.add_error(
                    f"Node '{node_id}' references block_id '{block_id}' which does not exist."
                )
                valid = False

        return valid

    def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
        """Validate all node IDs referenced in links exist."""
        valid = True
        valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}

        for link in agent.get("links", []):
            link_id = link.get("id", "Unknown")
            source_id = link.get("source_id")
            sink_id = link.get("sink_id")

            if not source_id:
                self.add_error(f"Link '{link_id}' is missing 'source_id'.")
                valid = False
            elif source_id not in valid_node_ids:
                self.add_error(
                    f"Link '{link_id}' references non-existent source_id '{source_id}'."
                )
                valid = False

            if not sink_id:
                self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
                valid = False
            elif sink_id not in valid_node_ids:
                self.add_error(
                    f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
                )
                valid = False

        return valid

    def validate_required_inputs(
        self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
    ) -> bool:
        """Validate required inputs are provided."""
        valid = True
        block_map = {b.get("id"): b for b in blocks_info}

        for node in agent.get("nodes", []):
            block_id = node.get("block_id")
            block = block_map.get(block_id)

            if not block:
                continue

            required_inputs = block.get("inputSchema", {}).get("required", [])
            input_defaults = node.get("input_default", {})
            node_id = node.get("id")

            # Get linked inputs
            linked_inputs = {
                link["sink_name"]
                for link in agent.get("links", [])
                if link.get("sink_id") == node_id
            }

            for req_input in required_inputs:
                if (
                    req_input not in input_defaults
                    and req_input not in linked_inputs
                    and req_input != "credentials"
                ):
                    block_name = block.get("name", "Unknown Block")
                    self.add_error(
                        f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
                    )
                    valid = False

        return valid

    def validate_data_type_compatibility(
        self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
    ) -> bool:
        """Validate linked data types are compatible."""
        valid = True
        block_map = {b.get("id"): b for b in blocks_info}
        node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}

        def get_type(schema: dict, name: str) -> str | None:
            if "_#_" in name:
                parent, child = name.split("_#_", 1)
                parent_schema = schema.get(parent, {})
                if "properties" in parent_schema:
                    return parent_schema["properties"].get(child, {}).get("type")
                return None
            return schema.get(name, {}).get("type")

        def are_compatible(src: str, sink: str) -> bool:
            if {src, sink} <= {"integer", "number"}:
                return True
            return src == sink

        for link in agent.get("links", []):
            source_node = node_lookup.get(link.get("source_id"))
            sink_node = node_lookup.get(link.get("sink_id"))

            if not source_node or not sink_node:
                continue

            source_block = block_map.get(source_node.get("block_id"))
            sink_block = block_map.get(sink_node.get("block_id"))

            if not source_block or not sink_block:
                continue

            source_outputs = source_block.get("outputSchema", {}).get("properties", {})
            sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})

            source_type = get_type(source_outputs, link.get("source_name", ""))
            sink_type = get_type(sink_inputs, link.get("sink_name", ""))

            if source_type and sink_type and not are_compatible(source_type, sink_type):
                self.add_error(
                    f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
                    f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
                )
                valid = False

        return valid

    def validate_nested_sink_links(
        self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
    ) -> bool:
        """Validate nested sink links (with _#_ notation)."""
        valid = True
        block_map = {b.get("id"): b for b in blocks_info}
        node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}

        for link in agent.get("links", []):
            sink_name = link.get("sink_name", "")

            if "_#_" in sink_name:
                parent, child = sink_name.split("_#_", 1)

                sink_node = node_lookup.get(link.get("sink_id"))
                if not sink_node:
                    continue

                block = block_map.get(sink_node.get("block_id"))
                if not block:
                    continue

                input_props = block.get("inputSchema", {}).get("properties", {})
                parent_schema = input_props.get(parent)

                if not parent_schema:
                    self.add_error(
                        f"Invalid nested link '{sink_name}': parent '{parent}' not found."
                    )
                    valid = False
                    continue

                if not parent_schema.get("additionalProperties"):
                    if not (
                        isinstance(parent_schema, dict)
                        and "properties" in parent_schema
                        and child in parent_schema.get("properties", {})
                    ):
                        self.add_error(
                            f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
                        )
                        valid = False

        return valid

    def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
        """Validate prompts don't have spaces in template variables."""
        valid = True

        for node in agent.get("nodes", []):
            input_default = node.get("input_default", {})
            prompt = input_default.get("prompt", "")

            if not isinstance(prompt, str):
                continue

            # Find {{...}} with spaces
            matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
            for match in matches:
                content = match.group(1)
                if " " in content:
                    self.add_error(
                        f"Node '{node.get('id')}' has spaces in template variable: "
                        f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
                    )
                    valid = False

        return valid

    def validate(
        self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
    ) -> tuple[bool, str | None]:
        """Run all validations.

        Returns:
            Tuple of (is_valid, error_message)
        """
        self.errors = []

        if blocks_info is None:
            blocks_info = get_blocks_info()

        checks = [
            self.validate_block_existence(agent, blocks_info),
            self.validate_link_node_references(agent),
            self.validate_required_inputs(agent, blocks_info),
            self.validate_data_type_compatibility(agent, blocks_info),
            self.validate_nested_sink_links(agent, blocks_info),
            self.validate_prompt_spaces(agent),
        ]

        all_passed = all(checks)

        if all_passed:
            logger.info("Agent validation successful")
            return True, None

        error_message = "Agent validation failed:\n"
        for i, error in enumerate(self.errors, 1):
            error_message += f"{i}. {error}\n"

        logger.warning(f"Agent validation failed with {len(self.errors)} errors")
        return False, error_message


def validate_agent(
    agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
) -> tuple[bool, str | None]:
    """Convenience function to validate an agent.

    Returns:
        Tuple of (is_valid, error_message)
    """
    validator = AgentValidator()
    return validator.validate(agent, blocks_info)
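Context for the deletion: callers used this module through validate_agent, which aggregated numbered errors from all six checks. A minimal sketch of what it reported — the agent shape here is illustrative, not from the codebase:

# Two-node agent with an unknown block and a dangling link, run through
# the removed validator. blocks_info defaults to get_blocks_info() when omitted.
agent = {
    "nodes": [{"id": "n1", "block_id": "not-a-real-block"}],
    "links": [
        {"id": "l1", "source_id": "n1", "sink_id": "n2",
         "source_name": "output", "sink_name": "input"},
    ],
}
is_valid, error_message = validate_agent(agent)
assert not is_valid
print(error_message)
# Agent validation failed:
# 1. Node 'n1' references block_id 'not-a-real-block' which does not exist.
# 2. Link 'l1' references non-existent sink_id 'n2'.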
@@ -8,12 +8,10 @@ from langfuse import observe
 from backend.api.features.chat.model import ChatSession
 
 from .agent_generator import (
-    apply_all_fixes,
+    AgentGeneratorNotConfiguredError,
     decompose_goal,
     generate_agent,
-    get_blocks_info,
     save_agent_to_library,
-    validate_agent,
 )
 from .base import BaseTool
 from .models import (
@@ -27,9 +25,6 @@ from .models import (
 
 logger = logging.getLogger(__name__)
 
-# Maximum retries for agent generation with validation feedback
-MAX_GENERATION_RETRIES = 2
-
 
 class CreateAgentTool(BaseTool):
     """Tool for creating agents from natural language descriptions."""
@@ -91,9 +86,8 @@ class CreateAgentTool(BaseTool):
 
         Flow:
         1. Decompose the description into steps (may return clarifying questions)
-        2. Generate agent JSON from the steps
-        3. Apply fixes to correct common LLM errors
-        4. Preview or save based on the save parameter
+        2. Generate agent JSON (external service handles fixing and validation)
+        3. Preview or save based on the save parameter
         """
         description = kwargs.get("description", "").strip()
         context = kwargs.get("context", "")
@@ -110,11 +104,13 @@ class CreateAgentTool(BaseTool):
         # Step 1: Decompose goal into steps
         try:
             decomposition_result = await decompose_goal(description, context)
-        except ValueError as e:
-            # Handle missing API key or configuration errors
+        except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
-                message=f"Agent generation is not configured: {str(e)}",
-                error="configuration_error",
+                message=(
+                    "Agent generation is not available. "
+                    "The Agent Generator service is not configured."
+                ),
+                error="service_not_configured",
                 session_id=session_id,
             )
 
@@ -171,72 +167,32 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )
 
-        # Step 2: Generate agent JSON with retry on validation failure
-        blocks_info = get_blocks_info()
-        agent_json = None
-        validation_errors = None
-
-        for attempt in range(MAX_GENERATION_RETRIES + 1):
-            # Generate agent (include validation errors from previous attempt)
-            if attempt == 0:
-                agent_json = await generate_agent(decomposition_result)
-            else:
-                # Retry with validation error feedback
-                logger.info(
-                    f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
-                )
-                retry_instructions = {
-                    **decomposition_result,
-                    "previous_errors": validation_errors,
-                    "retry_instructions": (
-                        "The previous generation had validation errors. "
-                        "Please fix these issues in the new generation:\n"
-                        f"{validation_errors}"
-                    ),
-                }
-                agent_json = await generate_agent(retry_instructions)
-
-            if agent_json is None:
-                if attempt == MAX_GENERATION_RETRIES:
-                    return ErrorResponse(
-                        message="Failed to generate the agent. Please try again.",
-                        error="Generation failed",
-                        session_id=session_id,
-                    )
-                continue
-
-            # Step 3: Apply fixes to correct common errors
-            agent_json = apply_all_fixes(agent_json, blocks_info)
-
-            # Step 4: Validate the agent
-            is_valid, validation_errors = validate_agent(agent_json, blocks_info)
-
-            if is_valid:
-                logger.info(f"Agent generated successfully on attempt {attempt + 1}")
-                break
-
-            logger.warning(
-                f"Validation failed on attempt {attempt + 1}: {validation_errors}"
+        # Step 2: Generate agent JSON (external service handles fixing and validation)
+        try:
+            agent_json = await generate_agent(decomposition_result)
+        except AgentGeneratorNotConfiguredError:
+            return ErrorResponse(
+                message=(
+                    "Agent generation is not available. "
+                    "The Agent Generator service is not configured."
+                ),
+                error="service_not_configured",
+                session_id=session_id,
             )
 
-            if attempt == MAX_GENERATION_RETRIES:
-                # Return error with validation details
-                return ErrorResponse(
-                    message=(
-                        f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
-                        f"Please try rephrasing your request or simplify the workflow."
-                    ),
-                    error="validation_failed",
-                    details={"validation_errors": validation_errors},
-                    session_id=session_id,
-                )
+        if agent_json is None:
+            return ErrorResponse(
+                message="Failed to generate the agent. Please try again.",
+                error="Generation failed",
+                session_id=session_id,
+            )
 
         agent_name = agent_json.get("name", "Generated Agent")
         agent_description = agent_json.get("description", "")
         node_count = len(agent_json.get("nodes", []))
         link_count = len(agent_json.get("links", []))
 
-        # Step 4: Preview or save
+        # Step 3: Preview or save
         if not save:
             return AgentPreviewResponse(
                 message=(
@@ -8,13 +8,10 @@ from langfuse import observe
 from backend.api.features.chat.model import ChatSession
 
 from .agent_generator import (
-    apply_agent_patch,
-    apply_all_fixes,
+    AgentGeneratorNotConfiguredError,
     generate_agent_patch,
     get_agent_as_json,
-    get_blocks_info,
     save_agent_to_library,
-    validate_agent,
 )
 from .base import BaseTool
 from .models import (
@@ -28,9 +25,6 @@ from .models import (
 
 logger = logging.getLogger(__name__)
 
-# Maximum retries for patch generation with validation feedback
-MAX_GENERATION_RETRIES = 2
-
 
 class EditAgentTool(BaseTool):
     """Tool for editing existing agents using natural language."""
@@ -43,7 +37,7 @@ class EditAgentTool(BaseTool):
     def description(self) -> str:
         return (
             "Edit an existing agent from the user's library using natural language. "
-            "Generates a patch to update the agent while preserving unchanged parts."
+            "Generates updates to the agent while preserving unchanged parts."
         )
 
     @property
@@ -98,9 +92,8 @@ class EditAgentTool(BaseTool):
 
         Flow:
         1. Fetch the current agent
-        2. Generate a patch based on the requested changes
-        3. Apply the patch to create an updated agent
-        4. Preview or save based on the save parameter
+        2. Generate updated agent (external service handles fixing and validation)
+        3. Preview or save based on the save parameter
         """
         agent_id = kwargs.get("agent_id", "").strip()
         changes = kwargs.get("changes", "").strip()
@@ -137,121 +130,58 @@ class EditAgentTool(BaseTool):
         if context:
             update_request = f"{changes}\n\nAdditional context:\n{context}"
 
-        # Step 2: Generate patch with retry on validation failure
-        blocks_info = get_blocks_info()
-        updated_agent = None
-        validation_errors = None
-        intent = "Applied requested changes"
-
-        for attempt in range(MAX_GENERATION_RETRIES + 1):
-            # Generate patch (include validation errors from previous attempt)
-            try:
-                if attempt == 0:
-                    patch_result = await generate_agent_patch(
-                        update_request, current_agent
-                    )
-                else:
-                    # Retry with validation error feedback
-                    logger.info(
-                        f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
-                    )
-                    retry_request = (
-                        f"{update_request}\n\n"
-                        f"IMPORTANT: The previous edit had validation errors. "
-                        f"Please fix these issues:\n{validation_errors}"
-                    )
-                    patch_result = await generate_agent_patch(
-                        retry_request, current_agent
-                    )
-            except ValueError as e:
-                # Handle missing API key or configuration errors
-                return ErrorResponse(
-                    message=f"Agent generation is not configured: {str(e)}",
-                    error="configuration_error",
-                    session_id=session_id,
-                )
-
-            if patch_result is None:
-                if attempt == MAX_GENERATION_RETRIES:
-                    return ErrorResponse(
-                        message="Failed to generate changes. Please try rephrasing.",
-                        error="Patch generation failed",
-                        session_id=session_id,
-                    )
-                continue
-
-            # Check if LLM returned clarifying questions
-            if patch_result.get("type") == "clarifying_questions":
-                questions = patch_result.get("questions", [])
-                return ClarificationNeededResponse(
-                    message=(
-                        "I need some more information about the changes. "
-                        "Please answer the following questions:"
-                    ),
-                    questions=[
-                        ClarifyingQuestion(
-                            question=q.get("question", ""),
-                            keyword=q.get("keyword", ""),
-                            example=q.get("example"),
-                        )
-                        for q in questions
-                    ],
-                    session_id=session_id,
-                )
-
-            # Step 3: Apply patch and fixes
-            try:
-                updated_agent = apply_agent_patch(current_agent, patch_result)
-                updated_agent = apply_all_fixes(updated_agent, blocks_info)
-            except Exception as e:
-                if attempt == MAX_GENERATION_RETRIES:
-                    return ErrorResponse(
-                        message=f"Failed to apply changes: {str(e)}",
-                        error="patch_apply_failed",
-                        details={"exception": str(e)},
-                        session_id=session_id,
-                    )
-                validation_errors = str(e)
-                continue
-
-            # Step 4: Validate the updated agent
-            is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
-
-            if is_valid:
-                logger.info(f"Agent edited successfully on attempt {attempt + 1}")
-                intent = patch_result.get("intent", "Applied requested changes")
-                break
-
-            logger.warning(
-                f"Validation failed on attempt {attempt + 1}: {validation_errors}"
+        # Step 2: Generate updated agent (external service handles fixing and validation)
+        try:
+            result = await generate_agent_patch(update_request, current_agent)
+        except AgentGeneratorNotConfiguredError:
+            return ErrorResponse(
+                message=(
+                    "Agent editing is not available. "
+                    "The Agent Generator service is not configured."
+                ),
+                error="service_not_configured",
+                session_id=session_id,
             )
 
-            if attempt == MAX_GENERATION_RETRIES:
-                # Return error with validation details
-                return ErrorResponse(
-                    message=(
-                        f"Updated agent has validation errors after "
-                        f"{MAX_GENERATION_RETRIES + 1} attempts. "
-                        f"Please try rephrasing your request or simplify the changes."
-                    ),
-                    error="validation_failed",
-                    details={"validation_errors": validation_errors},
-                    session_id=session_id,
-                )
-
-        # At this point, updated_agent is guaranteed to be set (we return on all failure paths)
-        assert updated_agent is not None
+        if result is None:
+            return ErrorResponse(
+                message="Failed to generate changes. Please try rephrasing.",
+                error="Update generation failed",
+                session_id=session_id,
+            )
+
+        # Check if LLM returned clarifying questions
+        if result.get("type") == "clarifying_questions":
+            questions = result.get("questions", [])
+            return ClarificationNeededResponse(
+                message=(
+                    "I need some more information about the changes. "
+                    "Please answer the following questions:"
+                ),
+                questions=[
+                    ClarifyingQuestion(
+                        question=q.get("question", ""),
+                        keyword=q.get("keyword", ""),
+                        example=q.get("example"),
+                    )
+                    for q in questions
+                ],
+                session_id=session_id,
+            )
+
+        # Result is the updated agent JSON
+        updated_agent = result
 
         agent_name = updated_agent.get("name", "Updated Agent")
         agent_description = updated_agent.get("description", "")
         node_count = len(updated_agent.get("nodes", []))
         link_count = len(updated_agent.get("links", []))
 
-        # Step 5: Preview or save
+        # Step 3: Preview or save
        if not save:
             return AgentPreviewResponse(
                 message=(
-                    f"I've updated the agent. Changes: {intent}. "
+                    f"I've updated the agent. "
                     f"The agent now has {node_count} blocks. "
                     f"Review it and call edit_agent with save=true to save the changes."
                 ),
@@ -277,10 +207,7 @@ class EditAgentTool(BaseTool):
         )
 
         return AgentSavedResponse(
-            message=(
-                f"Updated agent '{created_graph.name}' has been saved to your library! "
-                f"Changes: {intent}"
-            ),
+            message=f"Updated agent '{created_graph.name}' has been saved to your library!",
             agent_id=created_graph.id,
             agent_name=created_graph.name,
             library_agent_id=library_agent.id,
@@ -29,7 +29,7 @@ def mock_embedding_functions():
     yield
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent(setup_test_data):
     """Test that the run_agent tool successfully executes an approved agent"""
     # Use test data from fixture
@@ -70,7 +70,7 @@ async def test_run_agent(setup_test_data):
     assert result_data["graph_name"] == "Test Agent"
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_missing_inputs(setup_test_data):
     """Test that the run_agent tool returns error when inputs are missing"""
     # Use test data from fixture
@@ -106,7 +106,7 @@ async def test_run_agent_missing_inputs(setup_test_data):
     assert "message" in result_data
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_invalid_agent_id(setup_test_data):
     """Test that the run_agent tool returns error for invalid agent ID"""
     # Use test data from fixture
@@ -141,7 +141,7 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
     )
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_with_llm_credentials(setup_llm_test_data):
     """Test that run_agent works with an agent requiring LLM credentials"""
     # Use test data from fixture
@@ -185,7 +185,7 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
     assert result_data["graph_name"] == "LLM Test Agent"
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_data):
     """Test that run_agent returns available inputs when called without inputs or use_defaults."""
     user = setup_test_data["user"]
@@ -219,7 +219,7 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
     assert "inputs" in result_data["message"].lower()
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_with_use_defaults(setup_test_data):
     """Test that run_agent executes successfully with use_defaults=True."""
     user = setup_test_data["user"]
@@ -251,7 +251,7 @@ async def test_run_agent_with_use_defaults(setup_test_data):
     assert result_data["graph_id"] == graph.id
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
     """Test that run_agent returns setup_requirements when credentials are missing."""
    user = setup_firecrawl_test_data["user"]
@@ -285,7 +285,7 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
     assert len(setup_info["user_readiness"]["missing_credentials"]) > 0
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_invalid_slug_format(setup_test_data):
     """Test that run_agent returns error for invalid slug format (no slash)."""
     user = setup_test_data["user"]
@@ -313,7 +313,7 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
     assert "username/agent-name" in result_data["message"]
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_unauthenticated():
     """Test that run_agent returns need_login for unauthenticated users."""
     tool = RunAgentTool()
@@ -340,7 +340,7 @@ async def test_run_agent_unauthenticated():
     assert "sign in" in result_data["message"].lower()
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_schedule_without_cron(setup_test_data):
     """Test that run_agent returns error when scheduling without cron expression."""
     user = setup_test_data["user"]
@@ -372,7 +372,7 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
     assert "cron" in result_data["message"].lower()
 
 
-@pytest.mark.asyncio(scope="session")
+@pytest.mark.asyncio(loop_scope="session")
 async def test_run_agent_schedule_without_name(setup_test_data):
     """Test that run_agent returns error when scheduling without schedule_name."""
     user = setup_test_data["user"]
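Reviewer note on the mechanical rename above: pytest-asyncio (0.24 and later, as I understand it) deprecated the marker's `scope` argument in favor of `loop_scope`, which names the event-loop scope the test coroutine runs on:

# Before (deprecated): @pytest.mark.asyncio(scope="session")
# After: the test runs on a single session-scoped event loop
@pytest.mark.asyncio(loop_scope="session")
async def test_example():
    ...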
@@ -23,6 +23,7 @@ class PendingHumanReviewModel(BaseModel):
         id: Unique identifier for the review record
         user_id: ID of the user who must perform the review
         node_exec_id: ID of the node execution that created this review
+        node_id: ID of the node definition (for grouping reviews from same node)
         graph_exec_id: ID of the graph execution containing the node
         graph_id: ID of the graph template being executed
         graph_version: Version number of the graph template
@@ -37,6 +38,10 @@ class PendingHumanReviewModel(BaseModel):
     """
 
     node_exec_id: str = Field(description="Node execution ID (primary key)")
+    node_id: str = Field(
+        description="Node definition ID (for grouping)",
+        default="",  # Temporary default for test compatibility
+    )
     user_id: str = Field(description="User ID associated with the review")
     graph_exec_id: str = Field(description="Graph execution ID")
     graph_id: str = Field(description="Graph ID")
@@ -66,7 +71,9 @@ class PendingHumanReviewModel(BaseModel):
     )
 
     @classmethod
-    def from_db(cls, review: "PendingHumanReview") -> "PendingHumanReviewModel":
+    def from_db(
+        cls, review: "PendingHumanReview", node_id: str
+    ) -> "PendingHumanReviewModel":
         """
         Convert a database model to a response model.
 
@@ -74,9 +81,14 @@ class PendingHumanReviewModel(BaseModel):
         payload, instructions, and editable flag.
 
         Handles invalid data gracefully by using safe defaults.
+
+        Args:
+            review: Database review object
+            node_id: Node definition ID (fetched from NodeExecution)
         """
         return cls(
             node_exec_id=review.nodeExecId,
+            node_id=node_id,
             user_id=review.userId,
             graph_exec_id=review.graphExecId,
             graph_id=review.graphId,
@@ -107,6 +119,13 @@ class ReviewItem(BaseModel):
     reviewed_data: SafeJsonData | None = Field(
         None, description="Optional edited data (ignored if approved=False)"
     )
+    auto_approve_future: bool = Field(
+        default=False,
+        description=(
+            "If true and this review is approved, future executions of this same "
+            "block (node) will be automatically approved. This only affects approved reviews."
+        ),
+    )
 
     @field_validator("reviewed_data")
     @classmethod
@@ -174,6 +193,9 @@ class ReviewRequest(BaseModel):
     This request must include ALL pending reviews for a graph execution.
     Each review will be either approved (with optional data modifications)
     or rejected (data ignored). The execution will resume only after ALL reviews are processed.
+
+    Each review item can individually specify whether to auto-approve future executions
+    of the same block via the `auto_approve_future` field on ReviewItem.
     """
 
     reviews: List[ReviewItem] = Field(
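A sketch of the request body this enables — the field names come from ReviewItem above, while the endpoint payload shape and the ID value are illustrative:

review_request = {
    "reviews": [
        {
            "node_exec_id": "node-exec-123",  # illustrative ID
            "approved": True,
            "message": "Output looks correct",
            # reviewed_data is forced to None server-side when auto-approving,
            # so future auto-approvals replay the original payload unmodified.
            "reviewed_data": None,
            "auto_approve_future": True,
        }
    ]
}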
@@ -1,17 +1,27 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import Any, List
|
||||||
|
|
||||||
import autogpt_libs.auth as autogpt_auth_lib
|
import autogpt_libs.auth as autogpt_auth_lib
|
||||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
|
|
||||||
from backend.data.execution import get_graph_execution_meta
|
from backend.data.execution import (
|
||||||
|
ExecutionContext,
|
||||||
|
ExecutionStatus,
|
||||||
|
get_graph_execution_meta,
|
||||||
|
)
|
||||||
|
from backend.data.graph import get_graph_settings
|
||||||
from backend.data.human_review import (
|
from backend.data.human_review import (
|
||||||
|
create_auto_approval_record,
|
||||||
|
get_pending_reviews_by_node_exec_ids,
|
||||||
get_pending_reviews_for_execution,
|
get_pending_reviews_for_execution,
|
||||||
get_pending_reviews_for_user,
|
get_pending_reviews_for_user,
|
||||||
has_pending_reviews_for_graph_exec,
|
has_pending_reviews_for_graph_exec,
|
||||||
process_all_reviews_for_execution,
|
process_all_reviews_for_execution,
|
||||||
)
|
)
|
||||||
|
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||||
|
from backend.data.user import get_user_by_id
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
|
|
||||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||||
@@ -127,17 +137,70 @@ async def process_review_action(
|
|||||||
detail="At least one review must be provided",
|
detail="At least one review must be provided",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build review decisions map
|
# Batch fetch all requested reviews
|
||||||
|
reviews_map = await get_pending_reviews_by_node_exec_ids(
|
||||||
|
list(all_request_node_ids), user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate all reviews were found
|
||||||
|
missing_ids = all_request_node_ids - set(reviews_map.keys())
|
||||||
|
if missing_ids:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"No pending review found for node execution(s): {', '.join(missing_ids)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate all reviews belong to the same execution
|
||||||
|
graph_exec_ids = {review.graph_exec_id for review in reviews_map.values()}
|
||||||
|
if len(graph_exec_ids) > 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="All reviews in a single request must belong to the same execution.",
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_exec_id = next(iter(graph_exec_ids))
|
||||||
|
|
||||||
|
# Validate execution status before processing reviews
|
||||||
|
graph_exec_meta = await get_graph_execution_meta(
|
||||||
|
user_id=user_id, execution_id=graph_exec_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not graph_exec_meta:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Graph execution #{graph_exec_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only allow processing reviews if execution is paused for review
|
||||||
|
# or incomplete (partial execution with some reviews already processed)
|
||||||
|
if graph_exec_meta.status not in (
|
||||||
|
ExecutionStatus.REVIEW,
|
||||||
|
ExecutionStatus.INCOMPLETE,
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||||
|
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||||
|
f"Current status: {graph_exec_meta.status}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build review decisions map and track which reviews requested auto-approval
|
||||||
|
# Auto-approved reviews use original data (no modifications allowed)
|
||||||
review_decisions = {}
|
review_decisions = {}
|
||||||
|
auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag
|
||||||
|
|
||||||
for review in request.reviews:
|
for review in request.reviews:
|
||||||
review_status = (
|
review_status = (
|
||||||
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
|
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
|
||||||
)
|
)
|
||||||
|
# If this review requested auto-approval, don't allow data modifications
|
||||||
|
reviewed_data = None if review.auto_approve_future else review.reviewed_data
|
||||||
review_decisions[review.node_exec_id] = (
|
review_decisions[review.node_exec_id] = (
|
||||||
review_status,
|
review_status,
|
||||||
review.reviewed_data,
|
reviewed_data,
|
||||||
review.message,
|
review.message,
|
||||||
)
|
)
|
||||||
|
auto_approve_requests[review.node_exec_id] = review.auto_approve_future
|
||||||
|
|
||||||
# Process all reviews
|
# Process all reviews
|
||||||
updated_reviews = await process_all_reviews_for_execution(
|
updated_reviews = await process_all_reviews_for_execution(
|
||||||
@@ -145,6 +208,87 @@ async def process_review_action(
|
|||||||
review_decisions=review_decisions,
|
review_decisions=review_decisions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create auto-approval records for approved reviews that requested it
|
||||||
|
# Deduplicate by node_id to avoid race conditions when multiple reviews
|
||||||
|
# for the same node are processed in parallel
|
||||||
|
async def create_auto_approval_for_node(
|
||||||
|
node_id: str, review_result
|
||||||
|
) -> tuple[str, bool]:
|
||||||
|
"""
|
||||||
|
Create auto-approval record for a node.
|
||||||
|
Returns (node_id, success) tuple for tracking failures.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await create_auto_approval_record(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=review_result.graph_exec_id,
|
||||||
|
graph_id=review_result.graph_id,
|
||||||
|
graph_version=review_result.graph_version,
|
||||||
|
node_id=node_id,
|
||||||
|
payload=review_result.payload,
|
||||||
|
)
|
||||||
|
return (node_id, True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create auto-approval record for node {node_id}",
|
||||||
|
exc_info=e,
|
||||||
|
)
|
||||||
|
return (node_id, False)
|
||||||
|
|
||||||
|
# Collect node_exec_ids that need auto-approval
|
||||||
|
node_exec_ids_needing_auto_approval = [
|
||||||
|
node_exec_id
|
||||||
|
for node_exec_id, review_result in updated_reviews.items()
|
||||||
|
if review_result.status == ReviewStatus.APPROVED
|
||||||
|
and auto_approve_requests.get(node_exec_id, False)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Batch-fetch node executions to get node_ids
|
||||||
|
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||||
|
if node_exec_ids_needing_auto_approval:
|
||||||
|
from backend.data.execution import get_node_executions
|
||||||
|
|
||||||
|
node_execs = await get_node_executions(
|
||||||
|
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||||
|
)
|
||||||
|
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
||||||
|
|
||||||
|
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||||
|
node_exec = node_exec_map.get(node_exec_id)
|
||||||
|
if node_exec:
|
||||||
|
review_result = updated_reviews[node_exec_id]
|
||||||
|
# Use the first approved review for this node (deduplicate by node_id)
|
||||||
|
if node_exec.node_id not in nodes_needing_auto_approval:
|
||||||
|
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create auto-approval record for {node_exec_id}: "
|
||||||
|
f"Node execution not found. This may indicate a race condition "
|
||||||
|
f"or data inconsistency."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||||
|
auto_approval_results = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
create_auto_approval_for_node(node_id, review_result)
|
||||||
|
for node_id, review_result in nodes_needing_auto_approval.items()
|
||||||
|
],
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count auto-approval failures
|
||||||
|
auto_approval_failed_count = 0
|
||||||
|
for result in auto_approval_results:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
# Unexpected exception during auto-approval creation
|
||||||
|
auto_approval_failed_count += 1
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected exception during auto-approval creation: {result}"
|
||||||
|
)
|
||||||
|
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||||
|
# Auto-approval creation failed (returned False)
|
||||||
|
auto_approval_failed_count += 1
|
||||||
|
|
||||||
# Count results
|
# Count results
|
||||||
approved_count = sum(
|
approved_count = sum(
|
||||||
1
|
1
|
||||||
@@ -157,30 +301,53 @@ async def process_review_action(
         if review.status == ReviewStatus.REJECTED
     )
 
-    # Resume execution if we processed some reviews
+    # Resume execution only if ALL pending reviews for this execution have been processed
     if updated_reviews:
-        # Get graph execution ID from any processed review
-        first_review = next(iter(updated_reviews.values()))
-        graph_exec_id = first_review.graph_exec_id
-
-        # Check if any pending reviews remain for this execution
         still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
 
         if not still_has_pending:
-            # Resume execution
+            # Get the graph_id from any processed review
+            first_review = next(iter(updated_reviews.values()))
+
             try:
+                # Fetch user and settings to build complete execution context
+                user = await get_user_by_id(user_id)
+                settings = await get_graph_settings(
+                    user_id=user_id, graph_id=first_review.graph_id
+                )
+
+                # Preserve user's timezone preference when resuming execution
+                user_timezone = (
+                    user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
+                )
+
+                execution_context = ExecutionContext(
+                    human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
+                    sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
+                    user_timezone=user_timezone,
+                )
+
                 await add_graph_execution(
                     graph_id=first_review.graph_id,
                     user_id=user_id,
                     graph_exec_id=graph_exec_id,
+                    execution_context=execution_context,
                 )
                 logger.info(f"Resumed execution {graph_exec_id}")
             except Exception as e:
                 logger.error(f"Failed to resume execution {graph_exec_id}: {str(e)}")
 
+    # Build error message if auto-approvals failed
+    error_message = None
+    if auto_approval_failed_count > 0:
+        error_message = (
+            f"{auto_approval_failed_count} auto-approval setting(s) could not be saved. "
+            f"You may need to manually approve these reviews in future executions."
+        )
+
     return ReviewResponse(
         approved_count=approved_count,
         rejected_count=rejected_count,
-        failed_count=0,
-        error=None,
+        failed_count=auto_approval_failed_count,
+        error=error_message,
     )
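
The hunks above fan the auto-approval writes out in parallel after deduplicating by node_id, using asyncio.gather(..., return_exceptions=True) so one failed write cannot cancel the rest. A minimal standalone sketch of that pattern — create_auto_approval_for_node here is a stand-in for the real database helper:

import asyncio

# Hypothetical stand-in for the real DB write; returns (node_id, success).
async def create_auto_approval_for_node(node_id: str, review: str) -> tuple[str, bool]:
    await asyncio.sleep(0)  # simulate I/O
    return node_id, True

async def main() -> None:
    # Deduplicate: keep the first review seen per node_id.
    reviews = [("node-a", "approved"), ("node-a", "approved-again"), ("node-b", "approved")]
    nodes_needing_auto_approval: dict[str, str] = {}
    for node_id, review in reviews:
        nodes_needing_auto_approval.setdefault(node_id, review)

    # One write per node; return_exceptions=True keeps one failure
    # from cancelling the rest, so failures can be counted afterwards.
    results = await asyncio.gather(
        *[
            create_auto_approval_for_node(node_id, review)
            for node_id, review in nodes_needing_auto_approval.items()
        ],
        return_exceptions=True,
    )
    failed = sum(
        1
        for r in results
        if isinstance(r, Exception) or (isinstance(r, tuple) and not r[1])
    )
    print(f"{len(results) - failed} succeeded, {failed} failed")

asyncio.run(main())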
@@ -583,7 +583,13 @@ async def update_library_agent(
         )
         update_fields["isDeleted"] = is_deleted
     if settings is not None:
-        update_fields["settings"] = SafeJson(settings.model_dump())
+        existing_agent = await get_library_agent(id=library_agent_id, user_id=user_id)
+        current_settings_dict = (
+            existing_agent.settings.model_dump() if existing_agent.settings else {}
+        )
+        new_settings = settings.model_dump(exclude_unset=True)
+        merged_settings = {**current_settings_dict, **new_settings}
+        update_fields["settings"] = SafeJson(merged_settings)
 
     try:
         # If graph_version is provided, update to that specific version
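
The settings change above merges a partial update over the stored settings instead of overwriting them; exclude_unset=True keeps fields the caller never provided from clobbering existing values. A small sketch of the semantics, using an illustrative model rather than the real LibraryAgent settings:

from pydantic import BaseModel

class Settings(BaseModel):
    theme: str = "light"
    notifications: bool = True

current = {"theme": "dark", "notifications": False}  # as stored in the DB
patch = Settings(theme="solarized")                  # caller only set `theme`

# exclude_unset=True drops fields the caller never provided...
new = patch.model_dump(exclude_unset=True)           # {'theme': 'solarized'}
# ...so the shallow merge preserves `notifications` from the stored settings.
merged = {**current, **new}
assert merged == {"theme": "solarized", "notifications": False}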
@@ -20,6 +20,7 @@ from typing import AsyncGenerator
 
 import httpx
 import pytest
+import pytest_asyncio
 from autogpt_libs.api_key.keysmith import APIKeySmith
 from prisma.enums import APIKeyPermission
 from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
@@ -38,13 +39,13 @@ keysmith = APIKeySmith()
 # ============================================================================
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def test_user_id() -> str:
     """Test user ID for OAuth tests."""
     return str(uuid.uuid4())
 
 
-@pytest.fixture
+@pytest_asyncio.fixture(scope="session", loop_scope="session")
 async def test_user(server, test_user_id: str):
     """Create a test user in the database."""
     await PrismaUser.prisma().create(
@@ -67,7 +68,7 @@ async def test_user(server, test_user_id: str):
     await PrismaUser.prisma().delete(where={"id": test_user_id})
 
 
-@pytest.fixture
+@pytest_asyncio.fixture
 async def test_oauth_app(test_user: str):
     """Create a test OAuth application in the database."""
     app_id = str(uuid.uuid4())
@@ -122,7 +123,7 @@ def pkce_credentials() -> tuple[str, str]:
     return generate_pkce()
 
 
-@pytest.fixture
+@pytest_asyncio.fixture
 async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
     """
     Create an async HTTP client that talks directly to the FastAPI app.
@@ -287,7 +288,7 @@ async def test_authorize_invalid_client_returns_error(
     assert query_params["error"][0] == "invalid_client"
 
 
-@pytest.fixture
+@pytest_asyncio.fixture
 async def inactive_oauth_app(test_user: str):
     """Create an inactive test OAuth application in the database."""
     app_id = str(uuid.uuid4())
@@ -1004,7 +1005,7 @@ async def test_token_refresh_revoked(
     assert "revoked" in response.json()["detail"].lower()
 
 
-@pytest.fixture
+@pytest_asyncio.fixture
 async def other_oauth_app(test_user: str):
     """Create a second OAuth application for cross-app tests."""
     app_id = str(uuid.uuid4())
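
The fixture changes above swap bare @pytest.fixture for @pytest_asyncio.fixture on every async fixture. With pytest-asyncio in strict mode, async fixtures decorated with the plain pytest decorator are not awaited and tests receive a coroutine object instead of the fixture value. A minimal sketch of the corrected shape (requires the pytest-asyncio plugin, as in the suite above):

import pytest
import pytest_asyncio

@pytest_asyncio.fixture
async def user_id() -> str:
    # Async setup (e.g. a DB insert) is awaited properly here.
    return "user-123"

@pytest.mark.asyncio
async def test_user_id(user_id: str):
    # With a plain @pytest.fixture, `user_id` would be an un-awaited coroutine.
    assert user_id == "user-123"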
@@ -1552,7 +1552,7 @@ async def review_store_submission(
 
         # Generate embedding for approved listing (blocking - admin operation)
         # Inside transaction: if embedding fails, entire transaction rolls back
-        embedding_success = await ensure_embedding(
+        await ensure_embedding(
             version_id=store_listing_version_id,
             name=store_listing_version.name,
             description=store_listing_version.description,
@@ -1560,12 +1560,6 @@ async def review_store_submission(
             categories=store_listing_version.categories or [],
             tx=tx,
         )
-        if not embedding_success:
-            raise ValueError(
-                f"Failed to generate embedding for listing {store_listing_version_id}. "
-                "This is likely due to OpenAI API being unavailable. "
-                "Please try again later or contact support if the issue persists."
-            )
 
         await prisma.models.StoreListing.prisma(tx).update(
             where={"id": store_listing_version.StoreListing.id},
@@ -21,7 +21,6 @@ from backend.util.json import dumps
 
 logger = logging.getLogger(__name__)
 
-
 # OpenAI embedding model configuration
 EMBEDDING_MODEL = "text-embedding-3-small"
 # Embedding dimension for the model above
@@ -63,49 +62,42 @@ def build_searchable_text(
     return " ".join(parts)
 
 
-async def generate_embedding(text: str) -> list[float] | None:
+async def generate_embedding(text: str) -> list[float]:
     """
     Generate embedding for text using OpenAI API.
 
-    Returns None if embedding generation fails.
-    Fail-fast: no retries to maintain consistency with approval flow.
+    Raises exceptions on failure - caller should handle.
     """
-    try:
-        client = get_openai_client()
-        if not client:
-            logger.error("openai_internal_api_key not set, cannot generate embedding")
-            return None
+    client = get_openai_client()
+    if not client:
+        raise RuntimeError("openai_internal_api_key not set, cannot generate embedding")
 
     # Truncate text to token limit using tiktoken
     # Character-based truncation is insufficient because token ratios vary by content type
     enc = encoding_for_model(EMBEDDING_MODEL)
     tokens = enc.encode(text)
     if len(tokens) > EMBEDDING_MAX_TOKENS:
         tokens = tokens[:EMBEDDING_MAX_TOKENS]
         truncated_text = enc.decode(tokens)
         logger.info(
             f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
         )
     else:
         truncated_text = text
 
     start_time = time.time()
     response = await client.embeddings.create(
         model=EMBEDDING_MODEL,
         input=truncated_text,
     )
     latency_ms = (time.time() - start_time) * 1000
 
     embedding = response.data[0].embedding
     logger.info(
         f"Generated embedding: {len(embedding)} dims, "
         f"{len(tokens)} tokens, {latency_ms:.0f}ms"
     )
     return embedding
-
-    except Exception as e:
-        logger.error(f"Failed to generate embedding: {e}")
-        return None
 
 
 async def store_embedding(
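
The truncation kept above relies on tiktoken so the cutoff lands on token boundaries rather than character counts. A standalone sketch of the same technique; the 8191-token cap is the documented input limit for text-embedding-3-small, but treat the exact constant as an assumption to verify against your OpenAI account:

from tiktoken import encoding_for_model

EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_MAX_TOKENS = 8191  # assumed input limit for this model

def truncate_to_token_limit(text: str) -> str:
    enc = encoding_for_model(EMBEDDING_MODEL)
    tokens = enc.encode(text)
    if len(tokens) <= EMBEDDING_MAX_TOKENS:
        return text
    # Decode the clipped token list back to text; character counts would
    # over- or under-shoot because token ratios vary by content type.
    return enc.decode(tokens[:EMBEDDING_MAX_TOKENS])

print(len(truncate_to_token_limit("word " * 20_000).split()))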
@@ -144,48 +136,45 @@ async def store_content_embedding(
 
     New function for unified content embedding storage.
     Uses raw SQL since Prisma doesn't natively support pgvector.
 
+    Raises exceptions on failure - caller should handle.
     """
-    try:
-        client = tx if tx else prisma.get_client()
+    client = tx if tx else prisma.get_client()
 
     # Convert embedding to PostgreSQL vector format
     embedding_str = embedding_to_vector_string(embedding)
     metadata_json = dumps(metadata or {})
 
     # Upsert the embedding
     # WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
     # Use unqualified ::vector - pgvector is in search_path on all environments
     await execute_raw_with_schema(
         """
         INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
             "id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
         )
         VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
         ON CONFLICT ("contentType", "contentId", "userId")
         DO UPDATE SET
             "embedding" = $4::vector,
             "searchableText" = $5,
             "metadata" = $6::jsonb,
             "updatedAt" = NOW()
         WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
             AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
             AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
         """,
         content_type,
         content_id,
         user_id,
         embedding_str,
         searchable_text,
         metadata_json,
         client=client,
     )
 
     logger.info(f"Stored embedding for {content_type}:{content_id}")
     return True
-
-    except Exception as e:
-        logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
-        return False
 
 
 async def get_embedding(version_id: str) -> dict[str, Any] | None:
@@ -217,34 +206,31 @@ async def get_content_embedding(
 
     New function for unified content embedding retrieval.
     Returns dict with contentType, contentId, embedding, timestamps or None if not found.
 
+    Raises exceptions on failure - caller should handle.
     """
-    try:
-        result = await query_raw_with_schema(
+    result = await query_raw_with_schema(
         """
         SELECT
             "contentType",
             "contentId",
             "userId",
             "embedding"::text as "embedding",
             "searchableText",
             "metadata",
             "createdAt",
             "updatedAt"
         FROM {schema_prefix}"UnifiedContentEmbedding"
         WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
         """,
         content_type,
         content_id,
         user_id,
     )
 
     if result and len(result) > 0:
         return result[0]
     return None
-
-    except Exception as e:
-        logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
-        return None
 
 
 async def ensure_embedding(
@@ -272,46 +258,38 @@ async def ensure_embedding(
         tx: Optional transaction client
 
     Returns:
-        True if embedding exists/was created, False on failure
+        True if embedding exists/was created
 
+    Raises exceptions on failure - caller should handle.
     """
-    try:
-        # Check if embedding already exists
+    # Check if embedding already exists
     if not force:
         existing = await get_embedding(version_id)
         if existing and existing.get("embedding"):
             logger.debug(f"Embedding for version {version_id} already exists")
             return True
 
     # Build searchable text for embedding
-    searchable_text = build_searchable_text(
-        name, description, sub_heading, categories
-    )
+    searchable_text = build_searchable_text(name, description, sub_heading, categories)
 
     # Generate new embedding
     embedding = await generate_embedding(searchable_text)
-    if embedding is None:
-        logger.warning(f"Could not generate embedding for version {version_id}")
-        return False
 
     # Store the embedding with metadata using new function
     metadata = {
         "name": name,
         "subHeading": sub_heading,
         "categories": categories,
     }
     return await store_content_embedding(
         content_type=ContentType.STORE_AGENT,
         content_id=version_id,
         embedding=embedding,
         searchable_text=searchable_text,
         metadata=metadata,
         user_id=None,  # Store agents are public
         tx=tx,
     )
-
-    except Exception as e:
-        logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
-        return False
 
 
 async def delete_embedding(version_id: str) -> bool:
@@ -521,6 +499,24 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
         success = sum(1 for result in results if result is True)
         failed = len(results) - success
 
+        # Aggregate unique errors to avoid Sentry spam
+        if failed > 0:
+            # Group errors by type and message
+            error_summary: dict[str, int] = {}
+            for result in results:
+                if isinstance(result, Exception):
+                    error_key = f"{type(result).__name__}: {str(result)}"
+                    error_summary[error_key] = error_summary.get(error_key, 0) + 1
+
+            # Log aggregated error summary
+            error_details = ", ".join(
+                f"{error} ({count}x)" for error, count in error_summary.items()
+            )
+            logger.error(
+                f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
+                f"Errors: {error_details}"
+            )
+
         results_by_type[content_type.value] = {
             "processed": len(missing_items),
             "success": success,
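
The backfill hunk above collapses repeated failures into one log line per unique error rather than one line per item. The grouping itself is just a counter keyed by exception type and message; a runnable sketch:

results = [
    TimeoutError("openai timed out"),
    TimeoutError("openai timed out"),
    ValueError("bad listing"),
    True,
]

error_summary: dict[str, int] = {}
for result in results:
    if isinstance(result, Exception):
        error_key = f"{type(result).__name__}: {result}"
        error_summary[error_key] = error_summary.get(error_key, 0) + 1

error_details = ", ".join(f"{err} ({count}x)" for err, count in error_summary.items())
print(error_details)
# TimeoutError: openai timed out (2x), ValueError: bad listing (1x)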
@@ -557,11 +553,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
     }
 
 
-async def embed_query(query: str) -> list[float] | None:
+async def embed_query(query: str) -> list[float]:
     """
     Generate embedding for a search query.
 
     Same as generate_embedding but with clearer intent.
+    Raises exceptions on failure - caller should handle.
     """
     return await generate_embedding(query)
 
@@ -594,40 +591,30 @@ async def ensure_content_embedding(
         tx: Optional transaction client
 
     Returns:
-        True if embedding exists/was created, False on failure
+        True if embedding exists/was created
 
+    Raises exceptions on failure - caller should handle.
     """
-    try:
-        # Check if embedding already exists
+    # Check if embedding already exists
     if not force:
         existing = await get_content_embedding(content_type, content_id, user_id)
         if existing and existing.get("embedding"):
-            logger.debug(
-                f"Embedding for {content_type}:{content_id} already exists"
-            )
+            logger.debug(f"Embedding for {content_type}:{content_id} already exists")
             return True
 
     # Generate new embedding
     embedding = await generate_embedding(searchable_text)
-    if embedding is None:
-        logger.warning(
-            f"Could not generate embedding for {content_type}:{content_id}"
-        )
-        return False
 
     # Store the embedding
     return await store_content_embedding(
         content_type=content_type,
         content_id=content_id,
         embedding=embedding,
         searchable_text=searchable_text,
         metadata=metadata or {},
         user_id=user_id,
         tx=tx,
     )
-
-    except Exception as e:
-        logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
-        return False
 
 
 async def cleanup_orphaned_embeddings() -> dict[str, Any]:
@@ -854,9 +841,8 @@ async def semantic_search(
         limit = 100
 
     # Generate query embedding
-    query_embedding = await embed_query(query)
-
-    if query_embedding is not None:
+    try:
+        query_embedding = await embed_query(query)
         # Semantic search with embeddings
         embedding_str = embedding_to_vector_string(query_embedding)
 
@@ -907,24 +893,21 @@ async def semantic_search(
         """
         )
 
-        try:
-            results = await query_raw_with_schema(sql, *params)
-            return [
-                {
-                    "content_id": row["content_id"],
-                    "content_type": row["content_type"],
-                    "searchable_text": row["searchable_text"],
-                    "metadata": row["metadata"],
-                    "similarity": float(row["similarity"]),
-                }
-                for row in results
-            ]
-        except Exception as e:
-            logger.error(f"Semantic search failed: {e}")
-            # Fall through to lexical search below
-
-    # Fallback to lexical search if embeddings unavailable
-    logger.warning("Falling back to lexical search (embeddings unavailable)")
+        results = await query_raw_with_schema(sql, *params)
+        return [
+            {
+                "content_id": row["content_id"],
+                "content_type": row["content_type"],
+                "searchable_text": row["searchable_text"],
+                "metadata": row["metadata"],
+                "similarity": float(row["similarity"]),
+            }
+            for row in results
+        ]
+    except Exception as e:
+        logger.warning(f"Semantic search failed, falling back to lexical search: {e}")
+
+    # Fallback to lexical search if embeddings unavailable
 
     params_lexical: list[Any] = [limit]
     user_filter = ""
@@ -298,17 +298,16 @@ async def test_schema_handling_error_cases():
         mock_client.execute_raw.side_effect = Exception("Database error")
         mock_get_client.return_value = mock_client
 
-        result = await embeddings.store_content_embedding(
-            content_type=ContentType.STORE_AGENT,
-            content_id="test-id",
-            embedding=[0.1] * EMBEDDING_DIM,
-            searchable_text="test",
-            metadata=None,
-            user_id=None,
-        )
-
-        # Should return False on error, not raise
-        assert result is False
+        # Should raise exception on error
+        with pytest.raises(Exception, match="Database error"):
+            await embeddings.store_content_embedding(
+                content_type=ContentType.STORE_AGENT,
+                content_id="test-id",
+                embedding=[0.1] * EMBEDDING_DIM,
+                searchable_text="test",
+                metadata=None,
+                user_id=None,
+            )
 
 
 if __name__ == "__main__":
@@ -80,9 +80,8 @@ async def test_generate_embedding_no_api_key():
     ) as mock_get_client:
         mock_get_client.return_value = None
 
-        result = await embeddings.generate_embedding("test text")
-
-        assert result is None
+        with pytest.raises(RuntimeError, match="openai_internal_api_key not set"):
+            await embeddings.generate_embedding("test text")
 
 
 @pytest.mark.asyncio(loop_scope="session")
@@ -97,9 +96,8 @@ async def test_generate_embedding_api_error():
     ) as mock_get_client:
         mock_get_client.return_value = mock_client
 
-        result = await embeddings.generate_embedding("test text")
-
-        assert result is None
+        with pytest.raises(Exception, match="API Error"):
+            await embeddings.generate_embedding("test text")
 
 
 @pytest.mark.asyncio(loop_scope="session")
@@ -173,11 +171,10 @@ async def test_store_embedding_database_error(mocker):
 
     embedding = [0.1, 0.2, 0.3]
 
-    result = await embeddings.store_embedding(
-        version_id="test-version-id", embedding=embedding, tx=mock_client
-    )
-
-    assert result is False
+    with pytest.raises(Exception, match="Database error"):
+        await embeddings.store_embedding(
+            version_id="test-version-id", embedding=embedding, tx=mock_client
+        )
 
 
 @pytest.mark.asyncio(loop_scope="session")
@@ -277,17 +274,16 @@ async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
 async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
     """Test ensure_embedding when generation fails."""
     mock_get.return_value = None
-    mock_generate.return_value = None
+    mock_generate.side_effect = Exception("Generation failed")
 
-    result = await embeddings.ensure_embedding(
-        version_id="test-id",
-        name="Test",
-        description="Test description",
-        sub_heading="Test heading",
-        categories=["test"],
-    )
-
-    assert result is False
+    with pytest.raises(Exception, match="Generation failed"):
+        await embeddings.ensure_embedding(
+            version_id="test-id",
+            name="Test",
+            description="Test description",
+            sub_heading="Test heading",
+            categories=["test"],
+        )
 
 
 @pytest.mark.asyncio(loop_scope="session")
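
All four tests above follow the same mechanical migration: an `assert result is None` / `assert result is False` check becomes a pytest.raises block, matching the new fail-fast contract. A sketch of the shape, with a toy function standing in for the real embeddings module (requires pytest-asyncio, as used in the suite):

import pytest

async def generate(text: str) -> list[float]:
    raise RuntimeError("openai_internal_api_key not set")

@pytest.mark.asyncio
async def test_generate_raises():
    # Old style asserted the sentinel return value; the fail-fast contract
    # is asserted by expecting the exception (match= takes a regex).
    with pytest.raises(RuntimeError, match="not set"):
        await generate("test text")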
@@ -186,13 +186,12 @@ async def unified_hybrid_search(
 
     offset = (page - 1) * page_size
 
-    # Generate query embedding
-    query_embedding = await embed_query(query)
-
-    # Graceful degradation if embedding unavailable
-    if query_embedding is None or not query_embedding:
+    # Generate query embedding with graceful degradation
+    try:
+        query_embedding = await embed_query(query)
+    except Exception as e:
         logger.warning(
-            "Failed to generate query embedding - falling back to lexical-only search. "
+            f"Failed to generate query embedding - falling back to lexical-only search: {e}. "
             "Check that openai_internal_api_key is configured and OpenAI API is accessible."
         )
         query_embedding = [0.0] * EMBEDDING_DIM
@@ -464,13 +463,12 @@ async def hybrid_search(
 
     offset = (page - 1) * page_size
 
-    # Generate query embedding
-    query_embedding = await embed_query(query)
-
-    # Graceful degradation
-    if query_embedding is None or not query_embedding:
+    # Generate query embedding with graceful degradation
+    try:
+        query_embedding = await embed_query(query)
+    except Exception as e:
         logger.warning(
-            "Failed to generate query embedding - falling back to lexical-only search."
+            f"Failed to generate query embedding - falling back to lexical-only search: {e}"
        )
         query_embedding = [0.0] * EMBEDDING_DIM
     total_non_semantic = (
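
Both search entry points above now catch the embedding failure and substitute a zero vector so the lexical half of the hybrid query still runs. A compressed sketch of that fallback; EMBEDDING_DIM and embed_query stand in for the real module attributes, and the simulated outage is illustrative:

import asyncio
import logging

logger = logging.getLogger(__name__)
EMBEDDING_DIM = 1536  # assumed dimension for text-embedding-3-small

async def embed_query(query: str) -> list[float]:
    raise RuntimeError("embedding backend unavailable")  # simulated outage

async def search(query: str) -> list[float]:
    try:
        query_embedding = await embed_query(query)
    except Exception as e:
        # The zero vector carries no semantic signal, so ranking in the
        # downstream hybrid query is driven by the lexical score alone.
        logger.warning(f"Falling back to lexical-only search: {e}")
        query_embedding = [0.0] * EMBEDDING_DIM
    return query_embedding

print(sum(asyncio.run(search("agents"))))  # 0.0 - the fallback vector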
@@ -172,8 +172,8 @@ async def test_hybrid_search_without_embeddings():
     with patch(
         "backend.api.features.store.hybrid_search.query_raw_with_schema"
     ) as mock_query:
-        # Simulate embedding failure
-        mock_embed.return_value = None
+        # Simulate embedding failure by raising exception
+        mock_embed.side_effect = Exception("Embedding generation failed")
         mock_query.return_value = mock_results
 
         # Should NOT raise - graceful degradation
@@ -613,7 +613,9 @@ async def test_unified_hybrid_search_graceful_degradation():
         "backend.api.features.store.hybrid_search.embed_query"
     ) as mock_embed:
         mock_query.return_value = mock_results
-        mock_embed.return_value = None  # Embedding failure
+        mock_embed.side_effect = Exception(
+            "Embedding generation failed"
+        )  # Embedding failure
 
         # Should NOT raise - graceful degradation
         results, total = await unified_hybrid_search(
@@ -116,6 +116,7 @@ class PrintToConsoleBlock(Block):
             input_schema=PrintToConsoleBlock.Input,
             output_schema=PrintToConsoleBlock.Output,
             test_input={"text": "Hello, World!"},
+            is_sensitive_action=True,
             test_output=[
                 ("output", "Hello, World!"),
                 ("status", "printed"),
659
autogpt_platform/backend/backend/blocks/claude_code.py
Normal file
@@ -0,0 +1,659 @@
|
|||||||
|
import json
|
||||||
|
import shlex
|
||||||
|
import uuid
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||||
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.model import (
|
||||||
|
APIKeyCredentials,
|
||||||
|
CredentialsField,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
SchemaField,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeCodeExecutionError(Exception):
|
||||||
|
"""Exception raised when Claude Code execution fails.
|
||||||
|
|
||||||
|
Carries the sandbox_id so it can be returned to the user for cleanup
|
||||||
|
when dispose_sandbox=False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, sandbox_id: str = ""):
|
||||||
|
super().__init__(message)
|
||||||
|
self.sandbox_id = sandbox_id
|
||||||
|
|
||||||
|
|
||||||
|
# Test credentials for E2B
|
||||||
|
TEST_E2B_CREDENTIALS = APIKeyCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="e2b",
|
||||||
|
api_key=SecretStr("mock-e2b-api-key"),
|
||||||
|
title="Mock E2B API key",
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
TEST_E2B_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_E2B_CREDENTIALS.provider,
|
||||||
|
"id": TEST_E2B_CREDENTIALS.id,
|
||||||
|
"type": TEST_E2B_CREDENTIALS.type,
|
||||||
|
"title": TEST_E2B_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test credentials for Anthropic
|
||||||
|
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
|
||||||
|
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
|
||||||
|
provider="anthropic",
|
||||||
|
api_key=SecretStr("mock-anthropic-api-key"),
|
||||||
|
title="Mock Anthropic API key",
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
|
||||||
|
"id": TEST_ANTHROPIC_CREDENTIALS.id,
|
||||||
|
"type": TEST_ANTHROPIC_CREDENTIALS.type,
|
||||||
|
"title": TEST_ANTHROPIC_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeCodeBlock(Block):
|
||||||
|
"""
|
||||||
|
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
|
||||||
|
|
||||||
|
Claude Code can create files, install tools, run commands, and perform complex
|
||||||
|
coding tasks autonomously within a secure sandbox environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Use base template - we'll install Claude Code ourselves for latest version
|
||||||
|
DEFAULT_TEMPLATE = "base"
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
e2b_credentials: CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.E2B], Literal["api_key"]
|
||||||
|
] = CredentialsField(
|
||||||
|
description=(
|
||||||
|
"API key for the E2B platform to create the sandbox. "
|
||||||
|
"Get one on the [e2b website](https://e2b.dev/docs)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
anthropic_credentials: CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
|
||||||
|
] = CredentialsField(
|
||||||
|
description=(
|
||||||
|
"API key for Anthropic to power Claude Code. "
|
||||||
|
"Get one at [Anthropic's website](https://console.anthropic.com)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"The task or instruction for Claude Code to execute. "
|
||||||
|
"Claude Code can create files, install packages, run commands, "
|
||||||
|
"and perform complex coding tasks."
|
||||||
|
),
|
||||||
|
placeholder="Create a hello world index.html file",
|
||||||
|
default="",
|
||||||
|
advanced=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
timeout: int = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Sandbox timeout in seconds. Claude Code tasks can take "
|
||||||
|
"a while, so set this appropriately for your task complexity. "
|
||||||
|
"Note: This only applies when creating a new sandbox. "
|
||||||
|
"When reconnecting to an existing sandbox via sandbox_id, "
|
||||||
|
"the original timeout is retained."
|
||||||
|
),
|
||||||
|
default=300, # 5 minutes default
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
setup_commands: list[str] = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Optional shell commands to run before executing Claude Code. "
|
||||||
|
"Useful for installing dependencies or setting up the environment."
|
||||||
|
),
|
||||||
|
default_factory=list,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
working_directory: str = SchemaField(
|
||||||
|
description="Working directory for Claude Code to operate in.",
|
||||||
|
default="/home/user",
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Session/continuation support
|
||||||
|
session_id: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Session ID to resume a previous conversation. "
|
||||||
|
"Leave empty for a new conversation. "
|
||||||
|
"Use the session_id from a previous run to continue that conversation."
|
||||||
|
),
|
||||||
|
default="",
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
sandbox_id: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Sandbox ID to reconnect to an existing sandbox. "
|
||||||
|
"Required when resuming a session (along with session_id). "
|
||||||
|
"Use the sandbox_id from a previous run where dispose_sandbox was False."
|
||||||
|
),
|
||||||
|
default="",
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_history: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Previous conversation history to continue from. "
|
||||||
|
"Use this to restore context on a fresh sandbox if the previous one timed out. "
|
||||||
|
"Pass the conversation_history output from a previous run."
|
||||||
|
),
|
||||||
|
default="",
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
dispose_sandbox: bool = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Whether to dispose of the sandbox immediately after execution. "
|
||||||
|
"Set to False if you want to continue the conversation later "
|
||||||
|
"(you'll need both sandbox_id and session_id from the output)."
|
||||||
|
),
|
||||||
|
default=True,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class FileOutput(BaseModel):
|
||||||
|
"""A file extracted from the sandbox."""
|
||||||
|
|
||||||
|
path: str
|
||||||
|
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
||||||
|
name: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
response: str = SchemaField(
|
||||||
|
description="The output/response from Claude Code execution"
|
||||||
|
)
|
||||||
|
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
||||||
|
description=(
|
||||||
|
"List of text files created/modified by Claude Code during this execution. "
|
||||||
|
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conversation_history: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Full conversation history including this turn. "
|
||||||
|
"Pass this to conversation_history input to continue on a fresh sandbox "
|
||||||
|
"if the previous sandbox timed out."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session_id: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Session ID for this conversation. "
|
||||||
|
"Pass this back along with sandbox_id to continue the conversation."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sandbox_id: Optional[str] = SchemaField(
|
||||||
|
description=(
|
||||||
|
"ID of the sandbox instance. "
|
||||||
|
"Pass this back along with session_id to continue the conversation. "
|
||||||
|
"This is None if dispose_sandbox was True (sandbox was disposed)."
|
||||||
|
),
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if execution failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
|
||||||
|
description=(
|
||||||
|
"Execute tasks using Claude Code in an E2B sandbox. "
|
||||||
|
"Claude Code can create files, install tools, run commands, "
|
||||||
|
"and perform complex coding tasks autonomously."
|
||||||
|
),
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
|
||||||
|
input_schema=ClaudeCodeBlock.Input,
|
||||||
|
output_schema=ClaudeCodeBlock.Output,
|
||||||
|
test_credentials={
|
||||||
|
"e2b_credentials": TEST_E2B_CREDENTIALS,
|
||||||
|
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
|
||||||
|
},
|
||||||
|
test_input={
|
||||||
|
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
|
||||||
|
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
|
||||||
|
"prompt": "Create a hello world HTML file",
|
||||||
|
"timeout": 300,
|
||||||
|
"setup_commands": [],
|
||||||
|
"working_directory": "/home/user",
|
||||||
|
"session_id": "",
|
||||||
|
"sandbox_id": "",
|
||||||
|
"conversation_history": "",
|
||||||
|
"dispose_sandbox": True,
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
("response", "Created index.html with hello world content"),
|
||||||
|
(
|
||||||
|
"files",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"path": "/home/user/index.html",
|
||||||
|
"relative_path": "index.html",
|
||||||
|
"name": "index.html",
|
||||||
|
"content": "<html>Hello World</html>",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"conversation_history",
|
||||||
|
"User: Create a hello world HTML file\n"
|
||||||
|
"Claude: Created index.html with hello world content",
|
||||||
|
),
|
||||||
|
("session_id", str),
|
||||||
|
("sandbox_id", None), # None because dispose_sandbox=True in test_input
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"execute_claude_code": lambda *args, **kwargs: (
|
||||||
|
"Created index.html with hello world content", # response
|
||||||
|
[
|
||||||
|
ClaudeCodeBlock.FileOutput(
|
||||||
|
path="/home/user/index.html",
|
||||||
|
relative_path="index.html",
|
||||||
|
name="index.html",
|
||||||
|
content="<html>Hello World</html>",
|
||||||
|
)
|
||||||
|
], # files
|
||||||
|
"User: Create a hello world HTML file\n"
|
||||||
|
"Claude: Created index.html with hello world content", # conversation_history
|
||||||
|
"test-session-id", # session_id
|
||||||
|
"sandbox_id", # sandbox_id
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute_claude_code(
|
||||||
|
self,
|
||||||
|
e2b_api_key: str,
|
||||||
|
anthropic_api_key: str,
|
||||||
|
prompt: str,
|
||||||
|
timeout: int,
|
||||||
|
setup_commands: list[str],
|
||||||
|
working_directory: str,
|
||||||
|
session_id: str,
|
||||||
|
existing_sandbox_id: str,
|
||||||
|
conversation_history: str,
|
||||||
|
dispose_sandbox: bool,
|
||||||
|
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
|
||||||
|
"""
|
||||||
|
Execute Claude Code in an E2B sandbox.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (response, files, conversation_history, session_id, sandbox_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Validate that sandbox_id is provided when resuming a session
|
||||||
|
if session_id and not existing_sandbox_id:
|
||||||
|
raise ValueError(
|
||||||
|
"sandbox_id is required when resuming a session with session_id. "
|
||||||
|
"The session state is stored in the original sandbox. "
|
||||||
|
"If the sandbox has timed out, use conversation_history instead "
|
||||||
|
"to restore context on a fresh sandbox."
|
||||||
|
)
|
||||||
|
|
||||||
|
sandbox = None
|
||||||
|
sandbox_id = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Either reconnect to existing sandbox or create a new one
|
||||||
|
if existing_sandbox_id:
|
||||||
|
# Reconnect to existing sandbox for conversation continuation
|
||||||
|
sandbox = await BaseAsyncSandbox.connect(
|
||||||
|
sandbox_id=existing_sandbox_id,
|
||||||
|
api_key=e2b_api_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create new sandbox
|
||||||
|
sandbox = await BaseAsyncSandbox.create(
|
||||||
|
template=self.DEFAULT_TEMPLATE,
|
||||||
|
api_key=e2b_api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Install Claude Code from npm (ensures we get the latest version)
|
||||||
|
install_result = await sandbox.commands.run(
|
||||||
|
"npm install -g @anthropic-ai/claude-code@latest",
|
||||||
|
timeout=120, # 2 min timeout for install
|
||||||
|
)
|
||||||
|
if install_result.exit_code != 0:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to install Claude Code: {install_result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run any user-provided setup commands
|
||||||
|
for cmd in setup_commands:
|
||||||
|
setup_result = await sandbox.commands.run(cmd)
|
||||||
|
if setup_result.exit_code != 0:
|
||||||
|
raise Exception(
|
||||||
|
f"Setup command failed: {cmd}\n"
|
||||||
|
f"Exit code: {setup_result.exit_code}\n"
|
||||||
|
f"Stdout: {setup_result.stdout}\n"
|
||||||
|
f"Stderr: {setup_result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture sandbox_id immediately after creation/connection
|
||||||
|
# so it's available for error recovery if dispose_sandbox=False
|
||||||
|
sandbox_id = sandbox.sandbox_id
|
||||||
|
|
||||||
|
# Generate or use provided session ID
|
||||||
|
current_session_id = session_id if session_id else str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Build base Claude flags
|
||||||
|
base_flags = "-p --dangerously-skip-permissions --output-format json"
|
||||||
|
|
||||||
|
# Add conversation history context if provided (for fresh sandbox continuation)
|
||||||
|
history_flag = ""
|
||||||
|
if conversation_history and not session_id:
|
||||||
|
# Inject previous conversation as context via system prompt
|
||||||
|
# Use consistent escaping via _escape_prompt helper
|
||||||
|
escaped_history = self._escape_prompt(
|
||||||
|
f"Previous conversation context: {conversation_history}"
|
||||||
|
)
|
||||||
|
history_flag = f" --append-system-prompt {escaped_history}"
|
||||||
|
|
||||||
|
# Build Claude command based on whether we're resuming or starting new
|
||||||
|
# Use shlex.quote for working_directory and session IDs to prevent injection
|
||||||
|
safe_working_dir = shlex.quote(working_directory)
|
||||||
|
if session_id:
|
||||||
|
# Resuming existing session (sandbox still alive)
|
||||||
|
safe_session_id = shlex.quote(session_id)
|
||||||
|
claude_command = (
|
||||||
|
f"cd {safe_working_dir} && "
|
||||||
|
f"echo {self._escape_prompt(prompt)} | "
|
||||||
|
f"claude --resume {safe_session_id} {base_flags}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# New session with specific ID
|
||||||
|
safe_current_session_id = shlex.quote(current_session_id)
|
||||||
|
claude_command = (
|
||||||
|
f"cd {safe_working_dir} && "
|
||||||
|
f"echo {self._escape_prompt(prompt)} | "
|
||||||
|
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture timestamp before running Claude Code to filter files later
|
||||||
|
# Capture timestamp 1 second in the past to avoid race condition with file creation
|
||||||
|
timestamp_result = await sandbox.commands.run(
|
||||||
|
"date -u -d '1 second ago' +%Y-%m-%dT%H:%M:%S"
|
||||||
|
)
|
||||||
|
if timestamp_result.exit_code != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to capture timestamp: {timestamp_result.stderr}"
|
||||||
|
)
|
||||||
|
start_timestamp = (
|
||||||
|
timestamp_result.stdout.strip() if timestamp_result.stdout else None
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await sandbox.commands.run(
|
||||||
|
claude_command,
|
||||||
|
timeout=0, # No command timeout - let sandbox timeout handle it
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for command failure
|
||||||
|
if result.exit_code != 0:
|
||||||
|
error_msg = result.stderr or result.stdout or "Unknown error"
|
||||||
|
raise Exception(
|
||||||
|
f"Claude Code command failed with exit code {result.exit_code}:\n"
|
||||||
|
f"{error_msg}"
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_output = result.stdout or ""
|
||||||
|
|
||||||
|
# Parse JSON output to extract response and build conversation history
|
||||||
|
response = ""
|
||||||
|
new_conversation_history = conversation_history or ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# The JSON output contains the result
|
||||||
|
output_data = json.loads(raw_output)
|
||||||
|
response = output_data.get("result", raw_output)
|
||||||
|
|
||||||
|
# Build conversation history entry
|
||||||
|
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||||
|
if new_conversation_history:
|
||||||
|
new_conversation_history = (
|
||||||
|
f"{new_conversation_history}\n\n{turn_entry}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_conversation_history = turn_entry
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If not valid JSON, use raw output
|
||||||
|
response = raw_output
|
||||||
|
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||||
|
if new_conversation_history:
|
||||||
|
new_conversation_history = (
|
||||||
|
f"{new_conversation_history}\n\n{turn_entry}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_conversation_history = turn_entry
|
||||||
|
|
||||||
|
# Extract files created/modified during this run
|
||||||
|
files = await self._extract_files(
|
||||||
|
sandbox, working_directory, start_timestamp
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
response,
|
||||||
|
files,
|
||||||
|
new_conversation_history,
|
||||||
|
current_session_id,
|
||||||
|
sandbox_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Wrap exception with sandbox_id so caller can access/cleanup
|
||||||
|
# the preserved sandbox when dispose_sandbox=False
|
||||||
|
raise ClaudeCodeExecutionError(str(e), sandbox_id) from e
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if dispose_sandbox and sandbox:
|
||||||
|
await sandbox.kill()
|
||||||
|
|
||||||
|
async def _extract_files(
|
||||||
|
self,
|
||||||
|
sandbox: BaseAsyncSandbox,
|
||||||
|
working_directory: str,
|
||||||
|
since_timestamp: str | None = None,
|
||||||
|
) -> list["ClaudeCodeBlock.FileOutput"]:
|
||||||
|
"""
|
||||||
|
Extract text files created/modified during this Claude Code execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sandbox: The E2B sandbox instance
|
||||||
|
working_directory: Directory to search for files
|
||||||
|
since_timestamp: ISO timestamp - only return files modified after this time
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of FileOutput objects with path, relative_path, name, and content
|
||||||
|
"""
|
||||||
|
files: list[ClaudeCodeBlock.FileOutput] = []
|
||||||
|
|
||||||
|
# Text file extensions we can safely read as text
|
||||||
|
text_extensions = {
|
||||||
|
".txt",
|
||||||
|
".md",
|
||||||
|
".html",
|
||||||
|
".htm",
|
||||||
|
".css",
|
||||||
|
".js",
|
||||||
|
".ts",
|
||||||
|
".jsx",
|
||||||
|
".tsx",
|
||||||
|
".json",
|
||||||
|
".xml",
|
||||||
|
".yaml",
|
||||||
|
".yml",
|
||||||
|
".toml",
|
||||||
|
".ini",
|
||||||
|
".cfg",
|
||||||
|
".conf",
|
||||||
|
".py",
|
||||||
|
".rb",
|
||||||
|
".php",
|
||||||
|
".java",
|
||||||
|
".c",
|
||||||
|
".cpp",
|
||||||
|
".h",
|
||||||
|
".hpp",
|
||||||
|
".cs",
|
||||||
|
".go",
|
||||||
|
".rs",
|
||||||
|
".swift",
|
||||||
|
".kt",
|
||||||
|
".scala",
|
||||||
|
".sh",
|
||||||
|
".bash",
|
||||||
|
".zsh",
|
||||||
|
".sql",
|
||||||
|
".graphql",
|
||||||
|
".env",
|
||||||
|
".gitignore",
|
||||||
|
".dockerfile",
|
||||||
|
"Dockerfile",
|
||||||
|
".vue",
|
||||||
|
".svelte",
|
||||||
|
".astro",
|
||||||
|
".mdx",
|
||||||
|
".rst",
|
||||||
|
".tex",
|
||||||
|
".csv",
|
||||||
|
".log",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# List files recursively using find command
|
||||||
|
# Exclude node_modules and .git directories, but allow hidden files
|
||||||
|
# like .env and .gitignore (they're filtered by text_extensions later)
|
||||||
|
# Filter by timestamp to only get files created/modified during this run
|
||||||
|
safe_working_dir = shlex.quote(working_directory)
|
||||||
|
timestamp_filter = ""
|
||||||
|
if since_timestamp:
|
||||||
|
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
|
||||||
|
find_result = await sandbox.commands.run(
|
||||||
|
f"find {safe_working_dir} -type f "
|
||||||
|
f"{timestamp_filter}"
|
||||||
|
f"-not -path '*/node_modules/*' "
|
||||||
|
f"-not -path '*/.git/*' "
|
||||||
|
f"2>/dev/null"
|
||||||
|
)
|
||||||
|
|
||||||
|
if find_result.stdout:
|
||||||
|
for file_path in find_result.stdout.strip().split("\n"):
|
||||||
|
if not file_path:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if it's a text file we can read
|
||||||
|
is_text = any(
|
||||||
|
file_path.endswith(ext) for ext in text_extensions
|
||||||
|
) or file_path.endswith("Dockerfile")
|
||||||
|
|
||||||
|
if is_text:
|
||||||
|
try:
|
||||||
|
content = await sandbox.files.read(file_path)
|
||||||
|
# Handle bytes or string
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
content = content.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
# Extract filename from path
|
||||||
|
file_name = file_path.split("/")[-1]
|
||||||
|
|
||||||
|
# Calculate relative path by stripping working directory
|
||||||
|
relative_path = file_path
|
||||||
|
if file_path.startswith(working_directory):
|
||||||
|
relative_path = file_path[len(working_directory) :]
|
||||||
|
# Remove leading slash if present
|
||||||
|
if relative_path.startswith("/"):
|
||||||
|
relative_path = relative_path[1:]
|
||||||
|
|
||||||
|
files.append(
|
||||||
|
ClaudeCodeBlock.FileOutput(
|
||||||
|
path=file_path,
|
||||||
|
relative_path=relative_path,
|
||||||
|
name=file_name,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Skip files that can't be read
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# If file extraction fails, return empty results
|
||||||
|
pass
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
|
def _escape_prompt(self, prompt: str) -> str:
|
||||||
|
"""Escape the prompt for safe shell execution."""
|
||||||
|
# Use single quotes and escape any single quotes in the prompt
|
||||||
|
escaped = prompt.replace("'", "'\"'\"'")
|
||||||
|
return f"'{escaped}'"
|
||||||
|
|
||||||
|
    async def run(
        self,
        input_data: Input,
        *,
        e2b_credentials: APIKeyCredentials,
        anthropic_credentials: APIKeyCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            (
                response,
                files,
                conversation_history,
                session_id,
                sandbox_id,
            ) = await self.execute_claude_code(
                e2b_api_key=e2b_credentials.api_key.get_secret_value(),
                anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
                prompt=input_data.prompt,
                timeout=input_data.timeout,
                setup_commands=input_data.setup_commands,
                working_directory=input_data.working_directory,
                session_id=input_data.session_id,
                existing_sandbox_id=input_data.sandbox_id,
                conversation_history=input_data.conversation_history,
                dispose_sandbox=input_data.dispose_sandbox,
            )

            yield "response", response
            # Always yield files (empty list if none) to match Output schema
            yield "files", [f.model_dump() for f in files]
            # Always yield conversation_history so user can restore context on fresh sandbox
            yield "conversation_history", conversation_history
            # Always yield session_id so user can continue conversation
            yield "session_id", session_id
            # Always yield sandbox_id (None if disposed) to match Output schema
            yield "sandbox_id", sandbox_id if not input_data.dispose_sandbox else None

        except ClaudeCodeExecutionError as e:
            yield "error", str(e)
            # If sandbox was preserved (dispose_sandbox=False), yield sandbox_id
            # so user can reconnect to or clean up the orphaned sandbox
            if not input_data.dispose_sandbox and e.sandbox_id:
                yield "sandbox_id", e.sandbox_id
        except Exception as e:
            yield "error", str(e)
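Since run is an async generator yielding (name, value) pairs, callers drain the stream rather than awaiting a single result. A hedged harness sketch, assuming an already-constructed block instance and credential objects (the helper itself is hypothetical):

# Minimal sketch of consuming the block's output stream.
async def collect_outputs(block, input_data, e2b_creds, anthropic_creds) -> dict:
    outputs: dict = {}
    async for name, value in block.run(
        input_data,
        e2b_credentials=e2b_creds,
        anthropic_credentials=anthropic_creds,
    ):
        # Later yields with the same name overwrite earlier ones; note that
        # "sandbox_id" can be yielded a second time on the error path.
        outputs[name] = value
    return outputs
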
@@ -9,7 +9,7 @@ from typing import Any, Optional
 from prisma.enums import ReviewStatus
 from pydantic import BaseModel

-from backend.data.execution import ExecutionContext, ExecutionStatus
+from backend.data.execution import ExecutionStatus
 from backend.data.human_review import ReviewResult
 from backend.executor.manager import async_update_node_execution_status
 from backend.util.clients import get_database_manager_async_client
@@ -28,6 +28,11 @@ class ReviewDecision(BaseModel):
 class HITLReviewHelper:
     """Helper class for Human-In-The-Loop review operations."""

+    @staticmethod
+    async def check_approval(**kwargs) -> Optional[ReviewResult]:
+        """Check if there's an existing approval for this node execution."""
+        return await get_database_manager_async_client().check_approval(**kwargs)
+
     @staticmethod
     async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
         """Create or retrieve a human review from the database."""
@@ -55,11 +60,11 @@ class HITLReviewHelper:
     async def _handle_review_request(
         input_data: Any,
         user_id: str,
+        node_id: str,
         node_exec_id: str,
         graph_exec_id: str,
         graph_id: str,
         graph_version: int,
-        execution_context: ExecutionContext,
         block_name: str = "Block",
         editable: bool = False,
     ) -> Optional[ReviewResult]:
@@ -69,11 +74,11 @@ class HITLReviewHelper:
         Args:
             input_data: The input data to be reviewed
             user_id: ID of the user requesting the review
+            node_id: ID of the node in the graph definition
            node_exec_id: ID of the node execution
             graph_exec_id: ID of the graph execution
             graph_id: ID of the graph
             graph_version: Version of the graph
-            execution_context: Current execution context
             block_name: Name of the block requesting review
             editable: Whether the reviewer can edit the data

@@ -83,15 +88,41 @@ class HITLReviewHelper:
         Raises:
             Exception: If review creation or status update fails
         """
-        # Skip review if safe mode is disabled - return auto-approved result
-        if not execution_context.human_in_the_loop_safe_mode:
-            logger.info(
-                f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
-            )
-            return ReviewResult(
-                data=input_data,
-                status=ReviewStatus.APPROVED,
-                message="Auto-approved (safe mode disabled)",
-                processed=True,
-                node_exec_id=node_exec_id,
-            )
+        # Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
+        # are handled by the caller:
+        # - HITL blocks check human_in_the_loop_safe_mode in their run() method
+        # - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
+        # This function only handles checking for existing approvals.
+
+        # Check if this node has already been approved (normal or auto-approval)
+        if approval_result := await HITLReviewHelper.check_approval(
+            node_exec_id=node_exec_id,
+            graph_exec_id=graph_exec_id,
+            node_id=node_id,
+            user_id=user_id,
+            input_data=input_data,
+        ):
+            logger.info(
+                f"Block {block_name} skipping review for node {node_exec_id} - "
+                f"found existing approval"
+            )
+            # Return a new ReviewResult with the current node_exec_id but approved status
+            # For auto-approvals, always use current input_data
+            # For normal approvals, use approval_result.data unless it's None
+            is_auto_approval = approval_result.node_exec_id != node_exec_id
+            approved_data = (
+                input_data
+                if is_auto_approval
+                else (
+                    approval_result.data
+                    if approval_result.data is not None
+                    else input_data
+                )
+            )
+            return ReviewResult(
+                data=approved_data,
+                status=ReviewStatus.APPROVED,
+                message=approval_result.message,
+                processed=True,
+                node_exec_id=node_exec_id,
+            )
@@ -103,7 +134,7 @@ class HITLReviewHelper:
             graph_id=graph_id,
             graph_version=graph_version,
             input_data=input_data,
-            message=f"Review required for {block_name} execution",
+            message=block_name,  # Use block_name directly as the message
             editable=editable,
         )
@@ -129,11 +160,11 @@ class HITLReviewHelper:
     async def handle_review_decision(
         input_data: Any,
         user_id: str,
+        node_id: str,
         node_exec_id: str,
         graph_exec_id: str,
         graph_id: str,
         graph_version: int,
-        execution_context: ExecutionContext,
         block_name: str = "Block",
         editable: bool = False,
     ) -> Optional[ReviewDecision]:
@@ -143,11 +174,11 @@ class HITLReviewHelper:
         Args:
             input_data: The input data to be reviewed
             user_id: ID of the user requesting the review
+            node_id: ID of the node in the graph definition
             node_exec_id: ID of the node execution
             graph_exec_id: ID of the graph execution
             graph_id: ID of the graph
             graph_version: Version of the graph
-            execution_context: Current execution context
             block_name: Name of the block requesting review
             editable: Whether the reviewer can edit the data

@@ -158,11 +189,11 @@ class HITLReviewHelper:
         review_result = await HITLReviewHelper._handle_review_request(
             input_data=input_data,
             user_id=user_id,
+            node_id=node_id,
             node_exec_id=node_exec_id,
             graph_exec_id=graph_exec_id,
             graph_id=graph_id,
             graph_version=graph_version,
-            execution_context=execution_context,
             block_name=block_name,
             editable=editable,
         )
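The data-selection rule in the approval branch is compact; restated as a pure function it is easier to check. This mirror is illustrative only, names follow the diff:

def select_approved_data(input_data, approval_data, approval_node_exec_id, node_exec_id):
    """Mirror of the selection logic in _handle_review_request (illustrative).

    Auto-approvals are detected by a node_exec_id mismatch: the stored record
    carries the synthetic auto-approve key rather than this execution's ID,
    so the current input_data is used to avoid replaying a stale payload.
    Normal approvals replay the stored (possibly edited) data when present.
    """
    is_auto_approval = approval_node_exec_id != node_exec_id
    if is_auto_approval:
        return input_data
    return approval_data if approval_data is not None else input_data
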
@@ -97,6 +97,7 @@ class HumanInTheLoopBlock(Block):
         input_data: Input,
         *,
         user_id: str,
+        node_id: str,
         node_exec_id: str,
         graph_exec_id: str,
         graph_id: str,
@@ -115,12 +116,12 @@ class HumanInTheLoopBlock(Block):
         decision = await self.handle_review_decision(
             input_data=input_data.data,
             user_id=user_id,
+            node_id=node_id,
             node_exec_id=node_exec_id,
             graph_exec_id=graph_exec_id,
             graph_id=graph_id,
             graph_version=graph_version,
-            execution_context=execution_context,
-            block_name=self.name,
+            block_name=input_data.name,  # Use user-provided name instead of block type
             editable=input_data.editable,
         )
@@ -1,7 +1,7 @@
 import logging
 import os

-import pytest
+import pytest_asyncio
 from dotenv import load_dotenv

 from backend.util.logging import configure_logging
@@ -19,7 +19,7 @@ if not os.getenv("PRISMA_DEBUG"):
     prisma_logger.setLevel(logging.INFO)


-@pytest.fixture(scope="session")
+@pytest_asyncio.fixture(scope="session", loop_scope="session")
 async def server():
     from backend.util.test import SpinTestServer

@@ -27,7 +27,7 @@ async def server():
     yield server


-@pytest.fixture(scope="session", autouse=True)
+@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
 async def graph_cleanup(server):
     created_graph_ids = []
     original_create_graph = server.agent_server.test_create_graph
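Switching to pytest_asyncio.fixture with loop_scope="session" pins async fixtures to a single event loop for the whole session, which avoids "attached to a different loop" errors when session-scoped fixtures hold loop-bound resources such as connections. A minimal self-contained sketch, assuming pytest-asyncio >= 0.24 (fixture name is illustrative):

# conftest.py (sketch)
import asyncio
import pytest_asyncio

@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def shared_queue():
    # Created once, on the session-wide event loop; every test that
    # requests it sees the same loop-bound object.
    yield asyncio.Queue()

Tests that consume such a fixture should themselves run on that loop, e.g. via @pytest.mark.asyncio(loop_scope="session"); function-scoped loops (the default) would see a queue bound to a foreign loop.
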
@@ -441,6 +441,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         static_output: bool = False,
         block_type: BlockType = BlockType.STANDARD,
         webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
+        is_sensitive_action: bool = False,
     ):
         """
         Initialize the block with the given schema.
@@ -473,8 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         self.static_output = static_output
         self.block_type = block_type
         self.webhook_config = webhook_config
+        self.is_sensitive_action = is_sensitive_action
         self.execution_stats: NodeExecutionStats = NodeExecutionStats()
-        self.is_sensitive_action: bool = False

         if self.webhook_config:
             if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -622,6 +623,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         input_data: BlockInput,
         *,
         user_id: str,
+        node_id: str,
         node_exec_id: str,
         graph_exec_id: str,
         graph_id: str,
@@ -648,11 +650,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         decision = await HITLReviewHelper.handle_review_decision(
             input_data=input_data,
             user_id=user_id,
+            node_id=node_id,
             node_exec_id=node_exec_id,
             graph_exec_id=graph_exec_id,
             graph_id=graph_id,
             graph_version=graph_version,
-            execution_context=execution_context,
             block_name=self.name,
             editable=True,
         )
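Promoting is_sensitive_action from a hardcoded attribute to a constructor argument lets a block declare its sensitivity at definition time instead of mutating the attribute after construction. The design difference on a stripped-down stand-in class (not the real Block):

class MiniBlock:
    def __init__(self, name: str, is_sensitive_action: bool = False):
        self.name = name
        # The invariant is established atomically at construction time...
        self.is_sensitive_action = is_sensitive_action

block = MiniBlock("delete_records", is_sensitive_action=True)

# ...instead of the old two-step pattern, where a subclass had to remember
# to flip the attribute after __init__ returned:
legacy = MiniBlock("delete_records")
legacy.is_sensitive_action = True
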
@@ -6,10 +6,10 @@ Handles all database operations for pending human reviews.
 import asyncio
 import logging
 from datetime import datetime, timezone
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 from prisma.enums import ReviewStatus
-from prisma.models import PendingHumanReview
+from prisma.models import AgentNodeExecution, PendingHumanReview
 from prisma.types import PendingHumanReviewUpdateInput
 from pydantic import BaseModel

@@ -17,8 +17,12 @@ from backend.api.features.executions.review.model import (
     PendingHumanReviewModel,
     SafeJsonData,
 )
+from backend.data.execution import get_graph_execution_meta
 from backend.util.json import SafeJson

+if TYPE_CHECKING:
+    pass
+
 logger = logging.getLogger(__name__)

@@ -32,6 +36,125 @@ class ReviewResult(BaseModel):
     node_exec_id: str


+def get_auto_approve_key(graph_exec_id: str, node_id: str) -> str:
+    """Generate the special nodeExecId key for auto-approval records."""
+    return f"auto_approve_{graph_exec_id}_{node_id}"
+
+
+async def check_approval(
+    node_exec_id: str,
+    graph_exec_id: str,
+    node_id: str,
+    user_id: str,
+    input_data: SafeJsonData | None = None,
+) -> Optional[ReviewResult]:
+    """
+    Check if there's an existing approval for this node execution.
+
+    Checks both:
+    1. Normal approval by node_exec_id (previous run of the same node execution)
+    2. Auto-approval by special key pattern "auto_approve_{graph_exec_id}_{node_id}"
+
+    Args:
+        node_exec_id: ID of the node execution
+        graph_exec_id: ID of the graph execution
+        node_id: ID of the node definition (not execution)
+        user_id: ID of the user (for data isolation)
+        input_data: Current input data (used for auto-approvals to avoid stale data)
+
+    Returns:
+        ReviewResult if approval found (either normal or auto), None otherwise
+    """
+    auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
+
+    # Check for either normal approval or auto-approval in a single query
+    existing_review = await PendingHumanReview.prisma().find_first(
+        where={
+            "OR": [
+                {"nodeExecId": node_exec_id},
+                {"nodeExecId": auto_approve_key},
+            ],
+            "status": ReviewStatus.APPROVED,
+            "userId": user_id,
+        },
+    )
+
+    if existing_review:
+        is_auto_approval = existing_review.nodeExecId == auto_approve_key
+        logger.info(
+            f"Found {'auto-' if is_auto_approval else ''}approval for node {node_id} "
+            f"(exec: {node_exec_id}) in execution {graph_exec_id}"
+        )
+        # For auto-approvals, use current input_data to avoid replaying stale payload
+        # For normal approvals, use the stored payload (which may have been edited)
+        return ReviewResult(
+            data=(
+                input_data
+                if is_auto_approval and input_data is not None
+                else existing_review.payload
+            ),
+            status=ReviewStatus.APPROVED,
+            message=(
+                "Auto-approved (user approved all future actions for this node)"
+                if is_auto_approval
+                else existing_review.reviewMessage or ""
+            ),
+            processed=True,
+            node_exec_id=existing_review.nodeExecId,
+        )
+
+    return None
+
+
+async def create_auto_approval_record(
+    user_id: str,
+    graph_exec_id: str,
+    graph_id: str,
+    graph_version: int,
+    node_id: str,
+    payload: SafeJsonData,
+) -> None:
+    """
+    Create an auto-approval record for a node in this execution.
+
+    This is stored as a PendingHumanReview with a special nodeExecId pattern
+    and status=APPROVED, so future executions of the same node can skip review.
+
+    Raises:
+        ValueError: If the graph execution doesn't belong to the user
+    """
+    # Validate that the graph execution belongs to this user (defense in depth)
+    graph_exec = await get_graph_execution_meta(
+        user_id=user_id, execution_id=graph_exec_id
+    )
+    if not graph_exec:
+        raise ValueError(
+            f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
+        )
+
+    auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
+
+    await PendingHumanReview.prisma().upsert(
+        where={"nodeExecId": auto_approve_key},
+        data={
+            "create": {
+                "nodeExecId": auto_approve_key,
+                "userId": user_id,
+                "graphExecId": graph_exec_id,
+                "graphId": graph_id,
+                "graphVersion": graph_version,
+                "payload": SafeJson(payload),
+                "instructions": "Auto-approval record",
+                "editable": False,
+                "status": ReviewStatus.APPROVED,
+                "processed": True,
+                "reviewedAt": datetime.now(timezone.utc),
+            },
+            "update": {},  # Already exists, no update needed
+        },
+    )
+
+
 async def get_or_create_human_review(
     user_id: str,
     node_exec_id: str,
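Two details worth noting above: the auto-approval record reuses the unique nodeExecId column with a synthetic key, and the upsert with an empty "update" branch makes creation idempotent. A small illustration of the key scheme (IDs are made up):

def get_auto_approve_key(graph_exec_id: str, node_id: str) -> str:
    return f"auto_approve_{graph_exec_id}_{node_id}"

# One record per (execution, node definition) pair, no matter how many
# node executions a loop produces:
key = get_auto_approve_key("exec-42", "node-7")
print(key)  # auto_approve_exec-42_node-7

# Because the key is deterministic and unique, a concurrent second upsert
# with data={"create": {...}, "update": {}} is a no-op rather than a duplicate.
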
@@ -108,6 +231,87 @@ async def get_or_create_human_review(
     )


+async def get_pending_review_by_node_exec_id(
+    node_exec_id: str, user_id: str
+) -> Optional["PendingHumanReviewModel"]:
+    """
+    Get a pending review by its node execution ID.
+
+    Args:
+        node_exec_id: The node execution ID to look up
+        user_id: User ID for authorization (only returns if review belongs to this user)
+
+    Returns:
+        The pending review if found and belongs to user, None otherwise
+    """
+    review = await PendingHumanReview.prisma().find_first(
+        where={
+            "nodeExecId": node_exec_id,
+            "userId": user_id,
+            "status": ReviewStatus.WAITING,
+        }
+    )
+
+    if not review:
+        return None
+
+    # Local import to avoid event loop conflicts in tests
+    from backend.data.execution import get_node_execution
+
+    node_exec = await get_node_execution(review.nodeExecId)
+    node_id = node_exec.node_id if node_exec else review.nodeExecId
+    return PendingHumanReviewModel.from_db(review, node_id=node_id)
+
+
+async def get_pending_reviews_by_node_exec_ids(
+    node_exec_ids: list[str], user_id: str
+) -> dict[str, "PendingHumanReviewModel"]:
+    """
+    Get multiple pending reviews by their node execution IDs in a single batch query.
+
+    Args:
+        node_exec_ids: List of node execution IDs to look up
+        user_id: User ID for authorization (only returns reviews belonging to this user)
+
+    Returns:
+        Dictionary mapping node_exec_id -> PendingHumanReviewModel for found reviews
+    """
+    if not node_exec_ids:
+        return {}
+
+    reviews = await PendingHumanReview.prisma().find_many(
+        where={
+            "nodeExecId": {"in": node_exec_ids},
+            "userId": user_id,
+            "status": ReviewStatus.WAITING,
+        }
+    )
+
+    if not reviews:
+        return {}
+
+    # Batch fetch all node executions to avoid N+1 queries
+    node_exec_ids_to_fetch = [review.nodeExecId for review in reviews]
+    node_execs = await AgentNodeExecution.prisma().find_many(
+        where={"id": {"in": node_exec_ids_to_fetch}},
+        include={"Node": True},
+    )
+
+    # Create mapping from node_exec_id to node_id
+    node_exec_id_to_node_id = {
+        node_exec.id: node_exec.agentNodeId for node_exec in node_execs
+    }
+
+    result = {}
+    for review in reviews:
+        node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId)
+        result[review.nodeExecId] = PendingHumanReviewModel.from_db(
+            review, node_id=node_id
+        )
+
+    return result
+
+
 async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
     """
     Check if a graph execution has any pending reviews.
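The batch helper replaces N sequential lookups with two queries, one for the reviews and one for their node executions, joined in memory. The same shape in isolation, as a generic sketch rather than the Prisma API:

from dataclasses import dataclass

@dataclass
class Review:
    node_exec_id: str

@dataclass
class NodeExec:
    id: str
    node_id: str

def join_reviews_to_nodes(
    reviews: list[Review], node_execs: list[NodeExec]
) -> dict[str, str]:
    """Map review.node_exec_id -> node_id in a single in-memory pass,
    falling back to the exec ID when the node execution is missing."""
    by_id = {ne.id: ne.node_id for ne in node_execs}
    return {r.node_exec_id: by_id.get(r.node_exec_id, r.node_exec_id) for r in reviews}
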
@@ -137,8 +341,11 @@ async def get_pending_reviews_for_user(
         page_size: Number of reviews per page

     Returns:
-        List of pending review models
+        List of pending review models with node_id included
     """
+    # Local import to avoid event loop conflicts in tests
+    from backend.data.execution import get_node_execution
+
     # Calculate offset for pagination
     offset = (page - 1) * page_size

@@ -149,7 +356,14 @@ async def get_pending_reviews_for_user(
         take=page_size,
     )

-    return [PendingHumanReviewModel.from_db(review) for review in reviews]
+    # Fetch node_id for each review from NodeExecution
+    result = []
+    for review in reviews:
+        node_exec = await get_node_execution(review.nodeExecId)
+        node_id = node_exec.node_id if node_exec else review.nodeExecId
+        result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
+
+    return result


 async def get_pending_reviews_for_execution(
@@ -163,8 +377,11 @@ async def get_pending_reviews_for_execution(
         user_id: User ID for security validation

     Returns:
-        List of pending review models
+        List of pending review models with node_id included
     """
+    # Local import to avoid event loop conflicts in tests
+    from backend.data.execution import get_node_execution
+
     reviews = await PendingHumanReview.prisma().find_many(
         where={
             "userId": user_id,
@@ -174,7 +391,14 @@ async def get_pending_reviews_for_execution(
         order={"createdAt": "asc"},
     )

-    return [PendingHumanReviewModel.from_db(review) for review in reviews]
+    # Fetch node_id for each review from NodeExecution
+    result = []
+    for review in reviews:
+        node_exec = await get_node_execution(review.nodeExecId)
+        node_id = node_exec.node_id if node_exec else review.nodeExecId
+        result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
+
+    return result


 async def process_all_reviews_for_execution(
@@ -244,11 +468,19 @@ async def process_all_reviews_for_execution(
     # Note: Execution resumption is now handled at the API layer after ALL reviews
     # for an execution are processed (both approved and rejected)

-    # Return as dict for easy access
-    return {
-        review.nodeExecId: PendingHumanReviewModel.from_db(review)
-        for review in updated_reviews
-    }
+    # Fetch node_id for each review and return as dict for easy access
+    # Local import to avoid event loop conflicts in tests
+    from backend.data.execution import get_node_execution
+
+    result = {}
+    for review in updated_reviews:
+        node_exec = await get_node_execution(review.nodeExecId)
+        node_id = node_exec.node_id if node_exec else review.nodeExecId
+        result[review.nodeExecId] = PendingHumanReviewModel.from_db(
+            review, node_id=node_id
+        )
+
+    return result


 async def update_review_processed_status(node_exec_id: str, processed: bool) -> None:
@@ -256,3 +488,44 @@ async def update_review_processed_status(node_exec_id: str, processed: bool) -> None:
     await PendingHumanReview.prisma().update(
         where={"nodeExecId": node_exec_id}, data={"processed": processed}
     )
+
+
+async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str) -> int:
+    """
+    Cancel all pending reviews for a graph execution (e.g., when execution is stopped).
+
+    Marks all WAITING reviews as REJECTED with a message indicating the execution was stopped.
+
+    Args:
+        graph_exec_id: The graph execution ID
+        user_id: User ID who owns the execution (for security validation)
+
+    Returns:
+        Number of reviews cancelled
+
+    Raises:
+        ValueError: If the graph execution doesn't belong to the user
+    """
+    # Validate user ownership before cancelling reviews
+    graph_exec = await get_graph_execution_meta(
+        user_id=user_id, execution_id=graph_exec_id
+    )
+    if not graph_exec:
+        raise ValueError(
+            f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
+        )
+
+    result = await PendingHumanReview.prisma().update_many(
+        where={
+            "graphExecId": graph_exec_id,
+            "userId": user_id,
+            "status": ReviewStatus.WAITING,
+        },
+        data={
+            "status": ReviewStatus.REJECTED,
+            "reviewMessage": "Execution was stopped by user",
+            "processed": True,
+            "reviewedAt": datetime.now(timezone.utc),
+        },
+    )
+    return result
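update_many returns the number of rows it touched, so the caller can report how many reviews a stop swept up; zero simply means nothing was WAITING. A small hedged usage sketch (the wrapper function is hypothetical):

async def stop_and_report(graph_exec_id: str, user_id: str) -> None:
    # Sweeps every WAITING review for the execution into REJECTED in a
    # single statement; returns 0 when there was nothing to cancel.
    cancelled = await cancel_pending_reviews_for_execution(graph_exec_id, user_id)
    if cancelled:
        print(f"Cancelled {cancelled} pending review(s) for {graph_exec_id}")
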
@@ -36,7 +36,7 @@ def sample_db_review():
     return mock_review


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_get_or_create_human_review_new(
     mocker: pytest_mock.MockFixture,
     sample_db_review,
@@ -46,8 +46,8 @@ async def test_get_or_create_human_review_new(
     sample_db_review.status = ReviewStatus.WAITING
     sample_db_review.processed = False

-    mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
-    mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
+    mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
+    mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)

     result = await get_or_create_human_review(
         user_id="test-user-123",
@@ -64,7 +64,7 @@ async def test_get_or_create_human_review_new(
     assert result is None


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_get_or_create_human_review_approved(
     mocker: pytest_mock.MockFixture,
     sample_db_review,
@@ -75,8 +75,8 @@ async def test_get_or_create_human_review_approved(
     sample_db_review.processed = False
     sample_db_review.reviewMessage = "Looks good"

-    mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
-    mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
+    mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
+    mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)

     result = await get_or_create_human_review(
         user_id="test-user-123",
@@ -96,7 +96,7 @@ async def test_get_or_create_human_review_approved(
     assert result.message == "Looks good"


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_has_pending_reviews_for_graph_exec_true(
     mocker: pytest_mock.MockFixture,
 ):
@@ -109,7 +109,7 @@ async def test_has_pending_reviews_for_graph_exec_true(
     assert result is True


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_has_pending_reviews_for_graph_exec_false(
     mocker: pytest_mock.MockFixture,
 ):
@@ -122,7 +122,7 @@ async def test_has_pending_reviews_for_graph_exec_false(
     assert result is False


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_get_pending_reviews_for_user(
     mocker: pytest_mock.MockFixture,
     sample_db_review,
@@ -131,10 +131,19 @@ async def test_get_pending_reviews_for_user(
     mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
     mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])

+    # Mock get_node_execution to return node with node_id (async function)
+    mock_node_exec = Mock()
+    mock_node_exec.node_id = "test_node_def_789"
+    mocker.patch(
+        "backend.data.execution.get_node_execution",
+        new=AsyncMock(return_value=mock_node_exec),
+    )
+
     result = await get_pending_reviews_for_user("test_user", page=2, page_size=10)

     assert len(result) == 1
     assert result[0].node_exec_id == "test_node_123"
+    assert result[0].node_id == "test_node_def_789"

     # Verify pagination parameters
     call_args = mock_find_many.return_value.find_many.call_args
@@ -142,7 +151,7 @@ async def test_get_pending_reviews_for_user(
     assert call_args.kwargs["take"] == 10


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_get_pending_reviews_for_execution(
     mocker: pytest_mock.MockFixture,
     sample_db_review,
@@ -151,12 +160,21 @@ async def test_get_pending_reviews_for_execution(
     mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
     mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])

+    # Mock get_node_execution to return node with node_id (async function)
+    mock_node_exec = Mock()
+    mock_node_exec.node_id = "test_node_def_789"
+    mocker.patch(
+        "backend.data.execution.get_node_execution",
+        new=AsyncMock(return_value=mock_node_exec),
+    )
+
     result = await get_pending_reviews_for_execution(
         "test_graph_exec_456", "test-user-123"
     )

     assert len(result) == 1
     assert result[0].graph_exec_id == "test_graph_exec_456"
+    assert result[0].node_id == "test_node_def_789"

     # Verify it filters by execution and user
     call_args = mock_find_many.return_value.find_many.call_args
@@ -166,7 +184,7 @@ async def test_get_pending_reviews_for_execution(
     assert where_clause["status"] == ReviewStatus.WAITING


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_process_all_reviews_for_execution_success(
     mocker: pytest_mock.MockFixture,
     sample_db_review,
@@ -201,6 +219,14 @@ async def test_process_all_reviews_for_execution_success(
         new=AsyncMock(return_value=[updated_review]),
     )

+    # Mock get_node_execution to return node with node_id (async function)
+    mock_node_exec = Mock()
+    mock_node_exec.node_id = "test_node_def_789"
+    mocker.patch(
+        "backend.data.execution.get_node_execution",
+        new=AsyncMock(return_value=mock_node_exec),
+    )
+
     result = await process_all_reviews_for_execution(
         user_id="test-user-123",
         review_decisions={
@@ -211,9 +237,10 @@ async def test_process_all_reviews_for_execution_success(
     assert len(result) == 1
     assert "test_node_123" in result
     assert result["test_node_123"].status == ReviewStatus.APPROVED
+    assert result["test_node_123"].node_id == "test_node_def_789"


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_process_all_reviews_for_execution_validation_errors(
     mocker: pytest_mock.MockFixture,
 ):
@@ -233,7 +260,7 @@ async def test_process_all_reviews_for_execution_validation_errors(
     )


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_process_all_reviews_edit_permission_error(
     mocker: pytest_mock.MockFixture,
     sample_db_review,
@@ -259,7 +286,7 @@ async def test_process_all_reviews_edit_permission_error(
     )


-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_process_all_reviews_mixed_approval_rejection(
     mocker: pytest_mock.MockFixture,
     sample_db_review,
@@ -329,6 +356,14 @@ async def test_process_all_reviews_mixed_approval_rejection(
         new=AsyncMock(return_value=[approved_review, rejected_review]),
     )

+    # Mock get_node_execution to return node with node_id (async function)
+    mock_node_exec = Mock()
+    mock_node_exec.node_id = "test_node_def_789"
+    mocker.patch(
+        "backend.data.execution.get_node_execution",
+        new=AsyncMock(return_value=mock_node_exec),
+    )
+
     result = await process_all_reviews_for_execution(
         user_id="test-user-123",
         review_decisions={
@@ -340,3 +375,5 @@ async def test_process_all_reviews_mixed_approval_rejection(
     assert len(result) == 2
     assert "test_node_123" in result
     assert "test_node_456" in result
+    assert result["test_node_123"].node_id == "test_node_def_789"
+    assert result["test_node_456"].node_id == "test_node_def_789"
@@ -50,6 +50,8 @@ from backend.data.graph import (
     validate_graph_execution_permissions,
 )
 from backend.data.human_review import (
+    cancel_pending_reviews_for_execution,
+    check_approval,
     get_or_create_human_review,
     has_pending_reviews_for_graph_exec,
     update_review_processed_status,
@@ -190,6 +192,8 @@ class DatabaseManager(AppService):
     get_user_notification_preference = _(get_user_notification_preference)

     # Human In The Loop
+    cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
+    check_approval = _(check_approval)
     get_or_create_human_review = _(get_or_create_human_review)
     has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
     update_review_processed_status = _(update_review_processed_status)
@@ -313,6 +317,8 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     set_execution_kv_data = d.set_execution_kv_data

     # Human In The Loop
+    cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
+    check_approval = d.check_approval
     get_or_create_human_review = d.get_or_create_human_review
     update_review_processed_status = d.update_review_processed_status

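Both additions follow the file's registration pattern: each database function is wrapped once on the service class and referenced once on the async client, so callers in other processes get the same call surface. A deliberately simplified sketch of that expose/mirror idea (the decorator and classes are illustrative stand-ins, not the project's actual `_()` helper):

from typing import Any, Awaitable, Callable

AsyncFn = Callable[..., Awaitable[Any]]

def expose(func: AsyncFn) -> AsyncFn:
    # Stand-in for the service-side wrapper: tag the function so the
    # RPC layer knows to route remote calls to it.
    func.__exposed__ = True  # type: ignore[attr-defined]
    return func

async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str) -> int:
    return 0  # placeholder body; the real one lives next to the database

class DatabaseService:
    # Service side: runs in the process that owns the DB connection.
    # The client side mirrors the same attribute name, delegating over RPC.
    cancel_pending_reviews_for_execution = expose(cancel_pending_reviews_for_execution)
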
@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError

 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
+from backend.data import human_review as human_review_db
 from backend.data import onboarding as onboarding_db
 from backend.data import user as user_db
 from backend.data.block import (
@@ -749,9 +750,27 @@ async def stop_graph_execution(
     if graph_exec.status in [
         ExecutionStatus.QUEUED,
         ExecutionStatus.INCOMPLETE,
+        ExecutionStatus.REVIEW,
     ]:
-        # If the graph is still on the queue, we can prevent them from being executed
-        # by setting the status to TERMINATED.
+        # If the graph is queued/incomplete/paused for review, terminate immediately
+        # No need to wait for executor since it's not actively running
+
+        # If graph is in REVIEW status, clean up pending reviews before terminating
+        if graph_exec.status == ExecutionStatus.REVIEW:
+            # Use human_review_db if Prisma connected, else database manager
+            review_db = (
+                human_review_db
+                if prisma.is_connected()
+                else get_database_manager_async_client()
+            )
+            # Mark all pending reviews as rejected/cancelled
+            cancelled_count = await review_db.cancel_pending_reviews_for_execution(
+                graph_exec_id, user_id
+            )
+            logger.info(
+                f"Cancelled {cancelled_count} pending review(s) for stopped execution {graph_exec_id}"
+            )
+
         graph_exec.status = ExecutionStatus.TERMINATED

         await asyncio.gather(
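The review_db selection works because the in-process module and the RPC client deliberately expose the same function name and signature, so the stop path stays agnostic about where the call executes. Distilled into a helper (illustrative; the diff inlines this as a conditional expression):

def pick_review_backend(prisma_connected: bool, direct_db, rpc_client):
    """Return whichever backend can serve cancel_pending_reviews_for_execution.

    direct_db is the human_review module (in-process, needs a live Prisma
    connection); rpc_client proxies the same call through the database manager.
    """
    return direct_db if prisma_connected else rpc_client
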
@@ -887,9 +906,28 @@ async def add_graph_execution(
             nodes_to_skip=nodes_to_skip,
             execution_context=execution_context,
         )
-        logger.info(f"Publishing execution {graph_exec.id} to execution queue")
+        logger.info(f"Queueing execution {graph_exec.id}")
+
+        # Update execution status to QUEUED BEFORE publishing to prevent race condition
+        # where two concurrent requests could both publish the same execution
+        updated_exec = await edb.update_graph_execution_stats(
+            graph_exec_id=graph_exec.id,
+            status=ExecutionStatus.QUEUED,
+        )
+
+        # Verify the status update succeeded (prevents duplicate queueing in race conditions)
+        # If another request already updated the status, this execution will not be QUEUED
+        if not updated_exec or updated_exec.status != ExecutionStatus.QUEUED:
+            logger.warning(
+                f"Skipping queue publish for execution {graph_exec.id} - "
+                f"status update failed or execution already queued by another request"
+            )
+            return graph_exec
+
+        graph_exec.status = ExecutionStatus.QUEUED
+
         # Publish to execution queue for executor to pick up
+        # This happens AFTER status update to ensure only one request publishes
         exec_queue = await get_async_execution_queue()
         await exec_queue.publish_message(
             routing_key=GRAPH_EXECUTION_ROUTING_KEY,
@@ -897,13 +935,6 @@ async def add_graph_execution(
             exchange=GRAPH_EXECUTION_EXCHANGE,
         )
         logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
-
-        # Update execution status to QUEUED
-        graph_exec.status = ExecutionStatus.QUEUED
-        await edb.update_graph_execution_stats(
-            graph_exec_id=graph_exec.id,
-            status=graph_exec.status,
-        )
     except BaseException as e:
         err = str(e) or type(e).__name__
         if not graph_exec:
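The reordering turns the status write into a gate: whichever request's update lands first gets to publish, and the loser sees the already-updated row and returns without publishing. A minimal sketch of the claim-then-publish pattern against an in-memory store (all names hypothetical):

import threading

_statuses: dict[str, str] = {}
_lock = threading.Lock()

def try_mark_queued(exec_id: str) -> bool:
    """Atomically claim the right to publish: only the first caller
    transitions the execution out of its initial state."""
    with _lock:
        if _statuses.get(exec_id, "CREATED") == "CREATED":
            _statuses[exec_id] = "QUEUED"
            return True
        return False

def enqueue(exec_id: str, publish) -> None:
    # Claim first, publish second - the inverse order reintroduces the race.
    if try_mark_queued(exec_id):
        publish(exec_id)
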
@@ -4,6 +4,7 @@ import pytest
 from pytest_mock import MockerFixture

 from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
+from backend.data.execution import ExecutionStatus
 from backend.util.mock import MockObject


@@ -346,6 +347,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
     mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
     mock_graph_exec.id = "execution-id-123"
     mock_graph_exec.node_executions = []  # Add this to avoid AttributeError
+    mock_graph_exec.status = ExecutionStatus.QUEUED  # Required for race condition check
     mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()

     # Mock the queue and event bus
@@ -611,6 +613,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
     mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
     mock_graph_exec.id = "execution-id-123"
     mock_graph_exec.node_executions = []
+    mock_graph_exec.status = ExecutionStatus.QUEUED  # Required for race condition check

     # Track what's passed to to_graph_execution_entry
     captured_kwargs = {}
@@ -670,3 +673,232 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
     # Verify nodes_to_skip was passed to to_graph_execution_entry
     assert "nodes_to_skip" in captured_kwargs
     assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
+
+
+@pytest.mark.asyncio
+async def test_stop_graph_execution_in_review_status_cancels_pending_reviews(
+    mocker: MockerFixture,
+):
+    """Test that stopping an execution in REVIEW status cancels pending reviews."""
+    from backend.data.execution import ExecutionStatus, GraphExecutionMeta
+    from backend.executor.utils import stop_graph_execution
+
+    user_id = "test-user"
+    graph_exec_id = "test-exec-123"
+
+    # Mock graph execution in REVIEW status
+    mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
+    mock_graph_exec.id = graph_exec_id
+    mock_graph_exec.status = ExecutionStatus.REVIEW
+
+    # Mock dependencies
+    mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
+    mock_queue_client = mocker.AsyncMock()
+    mock_get_queue.return_value = mock_queue_client
+
+    mock_prisma = mocker.patch("backend.executor.utils.prisma")
+    mock_prisma.is_connected.return_value = True
+
+    mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
+    mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
+        return_value=2  # 2 reviews cancelled
+    )
+
+    mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
+    mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
+        return_value=mock_graph_exec
+    )
+    mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
+
+    mock_get_event_bus = mocker.patch(
+        "backend.executor.utils.get_async_execution_event_bus"
+    )
+    mock_event_bus = mocker.MagicMock()
+    mock_event_bus.publish = mocker.AsyncMock()
+    mock_get_event_bus.return_value = mock_event_bus
+
+    mock_get_child_executions = mocker.patch(
+        "backend.executor.utils._get_child_executions"
+    )
+    mock_get_child_executions.return_value = []  # No children
+
+    # Call stop_graph_execution with timeout to allow status check
+    await stop_graph_execution(
+        user_id=user_id,
+        graph_exec_id=graph_exec_id,
+        wait_timeout=1.0,  # Wait to allow status check
+        cascade=True,
+    )
+
+    # Verify pending reviews were cancelled
+    mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
+        graph_exec_id, user_id
+    )
+
+    # Verify execution status was updated to TERMINATED
+    mock_execution_db.update_graph_execution_stats.assert_called_once()
+    call_kwargs = mock_execution_db.update_graph_execution_stats.call_args[1]
+    assert call_kwargs["graph_exec_id"] == graph_exec_id
+    assert call_kwargs["status"] == ExecutionStatus.TERMINATED
+
+
+@pytest.mark.asyncio
+async def test_stop_graph_execution_with_database_manager_when_prisma_disconnected(
+    mocker: MockerFixture,
+):
+    """Test that stop uses database manager when Prisma is not connected."""
+    from backend.data.execution import ExecutionStatus, GraphExecutionMeta
+    from backend.executor.utils import stop_graph_execution
+
+    user_id = "test-user"
+    graph_exec_id = "test-exec-456"
+
+    # Mock graph execution in REVIEW status
+    mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
+    mock_graph_exec.id = graph_exec_id
+    mock_graph_exec.status = ExecutionStatus.REVIEW
+
+    # Mock dependencies
+    mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
+    mock_queue_client = mocker.AsyncMock()
+    mock_get_queue.return_value = mock_queue_client
+
+    # Prisma is NOT connected
+    mock_prisma = mocker.patch("backend.executor.utils.prisma")
+    mock_prisma.is_connected.return_value = False
+
+    # Mock database manager client
+    mock_get_db_manager = mocker.patch(
+        "backend.executor.utils.get_database_manager_async_client"
+    )
+    mock_db_manager = mocker.AsyncMock()
+    mock_db_manager.get_graph_execution_meta = mocker.AsyncMock(
+        return_value=mock_graph_exec
+    )
+    mock_db_manager.cancel_pending_reviews_for_execution = mocker.AsyncMock(
+        return_value=3  # 3 reviews cancelled
+    )
+    mock_db_manager.update_graph_execution_stats = mocker.AsyncMock()
+    mock_get_db_manager.return_value = mock_db_manager
+
+    mock_get_event_bus = mocker.patch(
+        "backend.executor.utils.get_async_execution_event_bus"
+    )
+    mock_event_bus = mocker.MagicMock()
+    mock_event_bus.publish = mocker.AsyncMock()
+    mock_get_event_bus.return_value = mock_event_bus
+
+    mock_get_child_executions = mocker.patch(
+        "backend.executor.utils._get_child_executions"
+    )
+    mock_get_child_executions.return_value = []  # No children
+
+    # Call stop_graph_execution with timeout
+    await stop_graph_execution(
+        user_id=user_id,
+        graph_exec_id=graph_exec_id,
+        wait_timeout=1.0,
+        cascade=True,
+    )
+
+    # Verify database manager was used for cancel_pending_reviews
+    mock_db_manager.cancel_pending_reviews_for_execution.assert_called_once_with(
+        graph_exec_id, user_id
+    )
+
+    # Verify execution status was updated via database manager
+    mock_db_manager.update_graph_execution_stats.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_stop_graph_execution_cascades_to_child_with_reviews(
+    mocker: MockerFixture,
+):
+    """Test that stopping parent execution cascades to children and cancels their reviews."""
+    from backend.data.execution import ExecutionStatus, GraphExecutionMeta
+    from backend.executor.utils import stop_graph_execution
+
+    user_id = "test-user"
+    parent_exec_id = "parent-exec"
+    child_exec_id = "child-exec"
+
+    # Mock parent execution in RUNNING status
+    mock_parent_exec = mocker.MagicMock(spec=GraphExecutionMeta)
+    mock_parent_exec.id = parent_exec_id
+    mock_parent_exec.status = ExecutionStatus.RUNNING
+
+    # Mock child execution in REVIEW status
+    mock_child_exec = mocker.MagicMock(spec=GraphExecutionMeta)
+    mock_child_exec.id = child_exec_id
+    mock_child_exec.status = ExecutionStatus.REVIEW
+
+    # Mock dependencies
+    mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
+    mock_queue_client = mocker.AsyncMock()
+    mock_get_queue.return_value = mock_queue_client
+
+    mock_prisma = mocker.patch("backend.executor.utils.prisma")
+    mock_prisma.is_connected.return_value = True
+
+    mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
+    mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
+        return_value=1  # 1 child review cancelled
+    )
+
+    # Mock execution_db to return different status based on which execution is queried
+    mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
+
+    # Track call count to simulate status transition
+    call_count = {"count": 0}
+
+    async def get_exec_meta_side_effect(execution_id, user_id):
+        call_count["count"] += 1
+        if execution_id == parent_exec_id:
+            # After a few calls (child processing happens), transition parent to TERMINATED
+            # This simulates the executor service processing the stop request
+            if call_count["count"] > 3:
+                mock_parent_exec.status = ExecutionStatus.TERMINATED
+            return mock_parent_exec
+        elif execution_id == child_exec_id:
+            return mock_child_exec
+        return None
+
+    mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
+        side_effect=get_exec_meta_side_effect
+    )
+    mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
+
+    mock_get_event_bus = mocker.patch(
+        "backend.executor.utils.get_async_execution_event_bus"
+    )
+    mock_event_bus = mocker.MagicMock()
+    mock_event_bus.publish = mocker.AsyncMock()
+    mock_get_event_bus.return_value = mock_event_bus
+
+    # Mock _get_child_executions to return the child
+    mock_get_child_executions = mocker.patch(
+        "backend.executor.utils._get_child_executions"
+    )
+
+    def get_children_side_effect(parent_id):
+        if parent_id == parent_exec_id:
+            return [mock_child_exec]
+        return []
+
+    mock_get_child_executions.side_effect = get_children_side_effect
+
+    # Call stop_graph_execution on parent with cascade=True
|
||||||
|
await stop_graph_execution(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=parent_exec_id,
|
||||||
|
wait_timeout=1.0,
|
||||||
|
cascade=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify child reviews were cancelled
|
||||||
|
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||||
|
child_exec_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify both parent and child status updates
|
||||||
|
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
||||||
|
|||||||
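The `call_count` side effect above simulates `stop_graph_execution` polling the execution's status until it reaches `TERMINATED` or `wait_timeout` expires. A minimal sketch of that loop, assuming it works the way these mocks imply (the real logic lives in `backend.executor.utils` and may differ):

```python
import asyncio
import time

from backend.data.execution import ExecutionStatus


async def _wait_for_termination(get_meta, graph_exec_id, user_id, wait_timeout: float):
    # Hedged sketch: poll the execution meta (as mocked above) until the
    # status flips to TERMINATED or the timeout elapses. The 0.1s poll
    # interval is an assumption, not taken from the source.
    deadline = time.monotonic() + wait_timeout
    while time.monotonic() < deadline:
        meta = await get_meta(graph_exec_id, user_id)
        if meta is not None and meta.status == ExecutionStatus.TERMINATED:
            return meta
        await asyncio.sleep(0.1)
    return None
```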
@@ -350,6 +350,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         description="Whether to mark failed scans as clean or not",
     )
 
+    agentgenerator_host: str = Field(
+        default="",
+        description="The host for the Agent Generator service (empty to use built-in)",
+    )
+    agentgenerator_port: int = Field(
+        default=8000,
+        description="The port for the Agent Generator service",
+    )
+    agentgenerator_timeout: int = Field(
+        default=120,
+        description="The timeout in seconds for Agent Generator service requests",
+    )
+
     enable_example_blocks: bool = Field(
         default=False,
         description="Whether to enable example blocks in production",
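These three settings configure the external Agent Generator client. A hedged sketch of how host and port likely combine into a base URL; the `http://{host}:{port}` shape is exactly what `test_get_base_url` in the service tests later in this diff asserts, while the `Settings` import path is an assumption:

```python
from backend.util.settings import Settings  # import path is an assumption


def agent_generator_base_url() -> str:
    # Mirrors the behavior asserted for service._get_base_url():
    # "agent-generator.local" + 8000 -> "http://agent-generator.local:8000"
    config = Settings().config
    return f"http://{config.agentgenerator_host}:{config.agentgenerator_port}"
```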
@@ -1,3 +1,4 @@
+import asyncio
 import inspect
 import logging
 import time
@@ -58,6 +59,11 @@ class SpinTestServer:
         self.db_api.__exit__(exc_type, exc_val, exc_tb)
         self.notif_manager.__exit__(exc_type, exc_val, exc_tb)
 
+        # Give services time to fully shut down
+        # This prevents event loop issues where services haven't fully cleaned up
+        # before the next test starts
+        await asyncio.sleep(0.5)
+
     def setup_dependency_overrides(self):
         # Override get_user_id for testing
         self.agent_server.set_test_dependency_overrides(
|
 -- CreateExtension
 -- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
--- Creates extension in current schema (determined by search_path from DATABASE_URL ?schema= param)
+-- Ensures vector extension is in the current schema (from DATABASE_URL ?schema= param)
+-- If it exists in a different schema (e.g., public), we drop and recreate it in the current schema
 -- This ensures vector type is in the same schema as tables, making ::vector work without explicit qualification
 DO $$
+DECLARE
+    current_schema_name text;
+    vector_schema text;
 BEGIN
-    CREATE EXTENSION IF NOT EXISTS "vector";
-EXCEPTION WHEN OTHERS THEN
-    RAISE NOTICE 'vector extension not available or already exists, skipping';
+    -- Get the current schema from search_path
+    SELECT current_schema() INTO current_schema_name;
+
+    -- Check if vector extension exists and which schema it's in
+    SELECT n.nspname INTO vector_schema
+    FROM pg_extension e
+    JOIN pg_namespace n ON e.extnamespace = n.oid
+    WHERE e.extname = 'vector';
+
+    -- Handle removal if in wrong schema
+    IF vector_schema IS NOT NULL AND vector_schema != current_schema_name THEN
+        BEGIN
+            -- Vector exists in a different schema, drop it first
+            RAISE WARNING 'pgvector found in schema "%" but need it in "%". Dropping and reinstalling...',
+                vector_schema, current_schema_name;
+            EXECUTE 'DROP EXTENSION IF EXISTS vector CASCADE';
+        EXCEPTION WHEN OTHERS THEN
+            RAISE EXCEPTION 'Failed to drop pgvector from schema "%": %. You may need to drop it manually.',
+                vector_schema, SQLERRM;
+        END;
+    END IF;
+
+    -- Create extension in current schema (let it fail naturally if not available)
+    EXECUTE format('CREATE EXTENSION IF NOT EXISTS vector SCHEMA %I', current_schema_name);
 END $$;
 
 -- CreateEnum
@@ -1,71 +0,0 @@
-- Acknowledge Supabase-managed extensions to prevent drift warnings
-- These extensions are pre-installed by Supabase in specific schemas
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)

-- Create schemas (safe in both CI and Supabase)
CREATE SCHEMA IF NOT EXISTS "extensions";

-- Extensions that exist in both CI and Supabase
DO $$
BEGIN
    CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
    RAISE NOTICE 'pgcrypto extension not available, skipping';
END $$;

DO $$
BEGIN
    CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
    RAISE NOTICE 'uuid-ossp extension not available, skipping';
END $$;

-- Supabase-specific extensions (skip gracefully in CI)
DO $$
BEGIN
    CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
    RAISE NOTICE 'pg_stat_statements extension not available, skipping';
END $$;

DO $$
BEGIN
    CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
    RAISE NOTICE 'pg_net extension not available, skipping';
END $$;

DO $$
BEGIN
    CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
EXCEPTION WHEN OTHERS THEN
    RAISE NOTICE 'pgjwt extension not available, skipping';
END $$;

DO $$
BEGIN
    CREATE SCHEMA IF NOT EXISTS "graphql";
    CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
EXCEPTION WHEN OTHERS THEN
    RAISE NOTICE 'pg_graphql extension not available, skipping';
END $$;

DO $$
BEGIN
    CREATE SCHEMA IF NOT EXISTS "pgsodium";
    CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
EXCEPTION WHEN OTHERS THEN
    RAISE NOTICE 'pgsodium extension not available, skipping';
END $$;

DO $$
BEGIN
    CREATE SCHEMA IF NOT EXISTS "vault";
    CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
EXCEPTION WHEN OTHERS THEN
    RAISE NOTICE 'supabase_vault extension not available, skipping';
END $$;


-- Return to platform
CREATE SCHEMA IF NOT EXISTS "platform";
@@ -0,0 +1,7 @@
-- Remove NodeExecution foreign key from PendingHumanReview
-- The nodeExecId column remains as the primary key, but we remove the FK constraint
-- to AgentNodeExecution since PendingHumanReview records can persist after node
-- execution records are deleted.

-- Drop foreign key constraint that linked PendingHumanReview.nodeExecId to AgentNodeExecution.id
ALTER TABLE "PendingHumanReview" DROP CONSTRAINT IF EXISTS "PendingHumanReview_nodeExecId_fkey";
@@ -517,8 +517,6 @@ model AgentNodeExecution {
 
   stats Json?
 
-  PendingHumanReview PendingHumanReview?
-
   @@index([agentGraphExecutionId, agentNodeId, executionStatus])
   @@index([agentNodeId, executionStatus])
   @@index([addedTime, queuedTime])
@@ -567,6 +565,7 @@ enum ReviewStatus {
 }
 
 // Pending human reviews for Human-in-the-loop blocks
+// Also stores auto-approval records with special nodeExecId patterns (e.g., "auto_approve_{graph_exec_id}_{node_id}")
 model PendingHumanReview {
   nodeExecId String @id
   userId     String
@@ -585,7 +584,6 @@ model PendingHumanReview {
   reviewedAt DateTime?
 
   User           User                @relation(fields: [userId], references: [id], onDelete: Cascade)
-  NodeExecution  AgentNodeExecution  @relation(fields: [nodeExecId], references: [id], onDelete: Cascade)
   GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade)
 
   @@unique([nodeExecId]) // One pending review per node execution
@@ -34,7 +34,10 @@ logger = logging.getLogger(__name__)
 
 # Default output directory relative to repo root
 DEFAULT_OUTPUT_DIR = (
-    Path(__file__).parent.parent.parent.parent / "docs" / "integrations"
+    Path(__file__).parent.parent.parent.parent
+    / "docs"
+    / "integrations"
+    / "block-integrations"
 )
 
 
@@ -421,6 +424,14 @@ def generate_block_markdown(
     lines.append("<!-- END MANUAL -->")
     lines.append("")
 
+    # Optional per-block extras (only include if has content)
+    extras = manual_content.get("extras", "")
+    if extras:
+        lines.append("<!-- MANUAL: extras -->")
+        lines.append(extras)
+        lines.append("<!-- END MANUAL -->")
+        lines.append("")
+
     lines.append("---")
     lines.append("")
 
@@ -456,25 +467,52 @@ def get_block_file_mapping(blocks: list[BlockDoc]) -> dict[str, list[BlockDoc]]:
     return dict(file_mapping)
 
 
-def generate_overview_table(blocks: list[BlockDoc]) -> str:
-    """Generate the overview table markdown (blocks.md)."""
+def generate_overview_table(blocks: list[BlockDoc], block_dir_prefix: str = "") -> str:
+    """Generate the overview table markdown (blocks.md).
+
+    Args:
+        blocks: List of block documentation objects
+        block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
+    """
     lines = []
+
+    # GitBook YAML frontmatter
+    lines.append("---")
+    lines.append("layout:")
+    lines.append("  width: default")
+    lines.append("  title:")
+    lines.append("    visible: true")
+    lines.append("  description:")
+    lines.append("    visible: true")
+    lines.append("  tableOfContents:")
+    lines.append("    visible: false")
+    lines.append("  outline:")
+    lines.append("    visible: true")
+    lines.append("  pagination:")
+    lines.append("    visible: true")
+    lines.append("  metadata:")
+    lines.append("    visible: true")
+    lines.append("---")
+    lines.append("")
+
     lines.append("# AutoGPT Blocks Overview")
     lines.append("")
     lines.append(
         'AutoGPT uses a modular approach with various "blocks" to handle different tasks. These blocks are the building blocks of AutoGPT workflows, allowing users to create complex automations by combining simple, specialized components.'
     )
     lines.append("")
-    lines.append('!!! info "Creating Your Own Blocks"')
-    lines.append("    Want to create your own custom blocks? Check out our guides:")
-    lines.append("    ")
+    lines.append('{% hint style="info" %}')
+    lines.append("**Creating Your Own Blocks**")
+    lines.append("")
+    lines.append("Want to create your own custom blocks? Check out our guides:")
+    lines.append("")
     lines.append(
-        "    - [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples"
+        "* [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples"
     )
     lines.append(
-        "    - [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration"
+        "* [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration"
     )
+    lines.append("{% endhint %}")
     lines.append("")
     lines.append(
         "Below is a comprehensive list of all available blocks, categorized by their primary function. Click on any block name to view its detailed documentation."
@@ -537,7 +575,8 @@ def generate_overview_table(blocks: list[BlockDoc]) -> str:
             else "No description"
         )
         short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
-        lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
+        link_path = f"{block_dir_prefix}{file_path}"
+        lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
         lines.append("")
         continue
 
@@ -563,13 +602,55 @@ def generate_overview_table(blocks: list[BlockDoc]) -> str:
         )
         short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
 
-        lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
+        link_path = f"{block_dir_prefix}{file_path}"
+        lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
 
     lines.append("")
 
     return "\n".join(lines)
 
 
+def generate_summary_md(
+    blocks: list[BlockDoc], root_dir: Path, block_dir_prefix: str = ""
+) -> str:
+    """Generate SUMMARY.md for GitBook navigation.
+
+    Args:
+        blocks: List of block documentation objects
+        root_dir: The root docs directory (e.g., docs/integrations/)
+        block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
+    """
+    lines = []
+    lines.append("# Table of contents")
+    lines.append("")
+    lines.append("* [AutoGPT Blocks Overview](README.md)")
+    lines.append("")
+
+    # Check for guides/ directory at the root level (docs/integrations/guides/)
+    guides_dir = root_dir / "guides"
+    if guides_dir.exists():
+        lines.append("## Guides")
+        lines.append("")
+        for guide_file in sorted(guides_dir.glob("*.md")):
+            # Use just the file name for title (replace hyphens/underscores with spaces)
+            title = file_path_to_title(guide_file.stem.replace("-", "_") + ".md")
+            lines.append(f"* [{title}](guides/{guide_file.name})")
+        lines.append("")
+
+    lines.append("## Block Integrations")
+    lines.append("")
+
+    file_mapping = get_block_file_mapping(blocks)
+    for file_path in sorted(file_mapping.keys()):
+        title = file_path_to_title(file_path)
+        link_path = f"{block_dir_prefix}{file_path}"
+        lines.append(f"* [{title}]({link_path})")
+
+    lines.append("")
+
+    return "\n".join(lines)
+
+
 def load_all_blocks_for_docs() -> list[BlockDoc]:
     """Load all blocks and extract documentation."""
     from backend.blocks import load_all_blocks
@@ -653,6 +734,16 @@ def write_block_docs(
             )
         )
+
+        # Add file-level additional_content section if present
+        file_additional = extract_manual_content(existing_content).get(
+            "additional_content", ""
+        )
+        if file_additional:
+            content_parts.append("<!-- MANUAL: additional_content -->")
+            content_parts.append(file_additional)
+            content_parts.append("<!-- END MANUAL -->")
+            content_parts.append("")
 
         full_content = file_header + "\n" + "\n".join(content_parts)
         generated_files[str(file_path)] = full_content
 
@@ -661,14 +752,28 @@ def write_block_docs(
 
             full_path.write_text(full_content)
 
-    # Generate overview file
-    overview_content = generate_overview_table(blocks)
-    overview_path = output_dir / "README.md"
+    # Generate overview file at the parent directory (docs/integrations/)
+    # with links prefixed to point into block-integrations/
+    root_dir = output_dir.parent
+    block_dir_name = output_dir.name  # "block-integrations"
+    block_dir_prefix = f"{block_dir_name}/"
+
+    overview_content = generate_overview_table(blocks, block_dir_prefix)
+    overview_path = root_dir / "README.md"
     generated_files["README.md"] = overview_content
     overview_path.write_text(overview_content)
 
     if verbose:
-        print("  Writing README.md (overview)")
+        print("  Writing README.md (overview) to parent directory")
+
+    # Generate SUMMARY.md for GitBook navigation at the parent directory
+    summary_content = generate_summary_md(blocks, root_dir, block_dir_prefix)
+    summary_path = root_dir / "SUMMARY.md"
+    generated_files["SUMMARY.md"] = summary_content
+    summary_path.write_text(summary_content)
+
+    if verbose:
+        print("  Writing SUMMARY.md (navigation) to parent directory")
 
     return generated_files
 
@@ -748,6 +853,16 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
             elif block_match.group(1).strip() != expected_block_content.strip():
                 mismatched_blocks.append(block.name)
 
+        # Add file-level additional_content to expected content (matches write_block_docs)
+        file_additional = extract_manual_content(existing_content).get(
+            "additional_content", ""
+        )
+        if file_additional:
+            content_parts.append("<!-- MANUAL: additional_content -->")
+            content_parts.append(file_additional)
+            content_parts.append("<!-- END MANUAL -->")
+            content_parts.append("")
+
         expected_content = file_header + "\n" + "\n".join(content_parts)
 
         if existing_content.strip() != expected_content.strip():
@@ -757,11 +872,15 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
             out_of_sync_details.append((file_path, mismatched_blocks))
             all_match = False
 
-    # Check overview
-    overview_path = output_dir / "README.md"
+    # Check overview at the parent directory (docs/integrations/)
+    root_dir = output_dir.parent
+    block_dir_name = output_dir.name  # "block-integrations"
+    block_dir_prefix = f"{block_dir_name}/"
+
+    overview_path = root_dir / "README.md"
     if overview_path.exists():
         existing_overview = overview_path.read_text()
-        expected_overview = generate_overview_table(blocks)
+        expected_overview = generate_overview_table(blocks, block_dir_prefix)
         if existing_overview.strip() != expected_overview.strip():
             print("OUT OF SYNC: README.md (overview)")
             print("  The blocks overview table needs regeneration")
@@ -772,6 +891,21 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
             out_of_sync_details.append(("README.md", ["overview table"]))
             all_match = False
 
+    # Check SUMMARY.md at the parent directory
+    summary_path = root_dir / "SUMMARY.md"
+    if summary_path.exists():
+        existing_summary = summary_path.read_text()
+        expected_summary = generate_summary_md(blocks, root_dir, block_dir_prefix)
+        if existing_summary.strip() != expected_summary.strip():
+            print("OUT OF SYNC: SUMMARY.md (navigation)")
+            print("  The GitBook navigation needs regeneration")
+            out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
+            all_match = False
+    else:
+        print("MISSING: SUMMARY.md (navigation)")
+        out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
+        all_match = False
+
     # Check for unfilled manual sections
     unfilled_patterns = [
         "_Add a description of this category of blocks._",
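A hypothetical driver showing how the two generators above fit together; `write_block_docs` already performs these steps itself, so this is only an illustration of the `block-integrations/` link prefix:

```python
from pathlib import Path

output_dir = Path("docs/integrations/block-integrations")
root_dir = output_dir.parent    # docs/integrations/
prefix = f"{output_dir.name}/"  # "block-integrations/"

blocks = load_all_blocks_for_docs()
(root_dir / "README.md").write_text(generate_overview_table(blocks, prefix))
(root_dir / "SUMMARY.md").write_text(generate_summary_md(blocks, root_dir, prefix))
```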
@@ -0,0 +1 @@
"""Tests for agent generator module."""
@@ -0,0 +1,273 @@
"""
Tests for the Agent Generator core module.

This test suite verifies that the core functions correctly delegate to
the external Agent Generator service.
"""

from unittest.mock import AsyncMock, patch

import pytest

from backend.api.features.chat.tools.agent_generator import core
from backend.api.features.chat.tools.agent_generator.core import (
    AgentGeneratorNotConfiguredError,
)


class TestServiceNotConfigured:
    """Test that functions raise AgentGeneratorNotConfiguredError when service is not configured."""

    @pytest.mark.asyncio
    async def test_decompose_goal_raises_when_not_configured(self):
        """Test that decompose_goal raises error when service not configured."""
        with patch.object(core, "is_external_service_configured", return_value=False):
            with pytest.raises(AgentGeneratorNotConfiguredError):
                await core.decompose_goal("Build a chatbot")

    @pytest.mark.asyncio
    async def test_generate_agent_raises_when_not_configured(self):
        """Test that generate_agent raises error when service not configured."""
        with patch.object(core, "is_external_service_configured", return_value=False):
            with pytest.raises(AgentGeneratorNotConfiguredError):
                await core.generate_agent({"steps": []})

    @pytest.mark.asyncio
    async def test_generate_agent_patch_raises_when_not_configured(self):
        """Test that generate_agent_patch raises error when service not configured."""
        with patch.object(core, "is_external_service_configured", return_value=False):
            with pytest.raises(AgentGeneratorNotConfiguredError):
                await core.generate_agent_patch("Add a node", {"nodes": []})


class TestDecomposeGoal:
    """Test decompose_goal function service delegation."""

    @pytest.mark.asyncio
    async def test_calls_external_service(self):
        """Test that decompose_goal calls the external service."""
        expected_result = {"type": "instructions", "steps": ["Step 1"]}

        with patch.object(
            core, "is_external_service_configured", return_value=True
        ), patch.object(
            core, "decompose_goal_external", new_callable=AsyncMock
        ) as mock_external:
            mock_external.return_value = expected_result

            result = await core.decompose_goal("Build a chatbot")

            mock_external.assert_called_once_with("Build a chatbot", "")
            assert result == expected_result

    @pytest.mark.asyncio
    async def test_passes_context_to_external_service(self):
        """Test that decompose_goal passes context to external service."""
        expected_result = {"type": "instructions", "steps": ["Step 1"]}

        with patch.object(
            core, "is_external_service_configured", return_value=True
        ), patch.object(
            core, "decompose_goal_external", new_callable=AsyncMock
        ) as mock_external:
            mock_external.return_value = expected_result

            await core.decompose_goal("Build a chatbot", "Use Python")

            mock_external.assert_called_once_with("Build a chatbot", "Use Python")

    @pytest.mark.asyncio
    async def test_returns_none_on_service_failure(self):
        """Test that decompose_goal returns None when external service fails."""
        with patch.object(
            core, "is_external_service_configured", return_value=True
        ), patch.object(
            core, "decompose_goal_external", new_callable=AsyncMock
        ) as mock_external:
            mock_external.return_value = None

            result = await core.decompose_goal("Build a chatbot")

            assert result is None


class TestGenerateAgent:
    """Test generate_agent function service delegation."""

    @pytest.mark.asyncio
    async def test_calls_external_service(self):
        """Test that generate_agent calls the external service."""
        expected_result = {"name": "Test Agent", "nodes": [], "links": []}

        with patch.object(
            core, "is_external_service_configured", return_value=True
        ), patch.object(
            core, "generate_agent_external", new_callable=AsyncMock
        ) as mock_external:
            mock_external.return_value = expected_result

            instructions = {"type": "instructions", "steps": ["Step 1"]}
            result = await core.generate_agent(instructions)

            mock_external.assert_called_once_with(instructions)
            # Result should have id, version, is_active added if not present
            assert result is not None
            assert result["name"] == "Test Agent"
            assert "id" in result
            assert result["version"] == 1
            assert result["is_active"] is True

    @pytest.mark.asyncio
    async def test_preserves_existing_id_and_version(self):
        """Test that external service result preserves existing id and version."""
        expected_result = {
            "id": "existing-id",
            "version": 3,
            "is_active": False,
            "name": "Test Agent",
        }

        with patch.object(
            core, "is_external_service_configured", return_value=True
        ), patch.object(
            core, "generate_agent_external", new_callable=AsyncMock
        ) as mock_external:
            mock_external.return_value = expected_result.copy()

            result = await core.generate_agent({"steps": []})

            assert result is not None
            assert result["id"] == "existing-id"
            assert result["version"] == 3
            assert result["is_active"] is False

    @pytest.mark.asyncio
    async def test_returns_none_when_external_service_fails(self):
        """Test that generate_agent returns None when external service fails."""
        with patch.object(
            core, "is_external_service_configured", return_value=True
        ), patch.object(
            core, "generate_agent_external", new_callable=AsyncMock
        ) as mock_external:
            mock_external.return_value = None

            result = await core.generate_agent({"steps": []})

            assert result is None


class TestGenerateAgentPatch:
    """Test generate_agent_patch function service delegation."""

    @pytest.mark.asyncio
    async def test_calls_external_service(self):
        """Test that generate_agent_patch calls the external service."""
        expected_result = {"name": "Updated Agent", "nodes": [], "links": []}

        with patch.object(
            core, "is_external_service_configured", return_value=True
        ), patch.object(
            core, "generate_agent_patch_external", new_callable=AsyncMock
        ) as mock_external:
            mock_external.return_value = expected_result

            current_agent = {"nodes": [], "links": []}
            result = await core.generate_agent_patch("Add a node", current_agent)

            mock_external.assert_called_once_with("Add a node", current_agent)
            assert result == expected_result

    @pytest.mark.asyncio
    async def test_returns_clarifying_questions(self):
        """Test that generate_agent_patch returns clarifying questions."""
        expected_result = {
            "type": "clarifying_questions",
            "questions": [{"question": "What type of node?"}],
        }

        with patch.object(
            core, "is_external_service_configured", return_value=True
        ), patch.object(
            core, "generate_agent_patch_external", new_callable=AsyncMock
        ) as mock_external:
            mock_external.return_value = expected_result

            result = await core.generate_agent_patch("Add a node", {"nodes": []})

            assert result == expected_result

    @pytest.mark.asyncio
    async def test_returns_none_when_external_service_fails(self):
        """Test that generate_agent_patch returns None when service fails."""
        with patch.object(
            core, "is_external_service_configured", return_value=True
        ), patch.object(
            core, "generate_agent_patch_external", new_callable=AsyncMock
        ) as mock_external:
            mock_external.return_value = None

            result = await core.generate_agent_patch("Add a node", {"nodes": []})

            assert result is None


class TestJsonToGraph:
    """Test json_to_graph function."""

    def test_converts_agent_json_to_graph(self):
        """Test conversion of agent JSON to Graph model."""
        agent_json = {
            "id": "test-id",
            "version": 2,
            "is_active": True,
            "name": "Test Agent",
            "description": "A test agent",
            "nodes": [
                {
                    "id": "node1",
                    "block_id": "block1",
                    "input_default": {"key": "value"},
                    "metadata": {"x": 100},
                }
            ],
            "links": [
                {
                    "id": "link1",
                    "source_id": "node1",
                    "sink_id": "output",
                    "source_name": "result",
                    "sink_name": "input",
                    "is_static": False,
                }
            ],
        }

        graph = core.json_to_graph(agent_json)

        assert graph.id == "test-id"
        assert graph.version == 2
        assert graph.is_active is True
        assert graph.name == "Test Agent"
        assert graph.description == "A test agent"
        assert len(graph.nodes) == 1
        assert graph.nodes[0].id == "node1"
        assert graph.nodes[0].block_id == "block1"
        assert len(graph.links) == 1
        assert graph.links[0].source_id == "node1"

    def test_generates_ids_if_missing(self):
        """Test that missing IDs are generated."""
        agent_json = {
            "name": "Test Agent",
            "nodes": [{"block_id": "block1"}],
            "links": [],
        }

        graph = core.json_to_graph(agent_json)

        assert graph.id is not None
        assert graph.nodes[0].id is not None


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
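These tests pin down a thin delegation contract: check configuration, raise if absent, otherwise forward to the external client and pass `None` failures through. A hedged sketch of the shape they imply, not the actual `core` module source:

```python
async def decompose_goal(description: str, context: str = ""):
    # Sketch of the delegation pattern asserted above; the real
    # implementation in agent_generator.core may differ in detail.
    if not is_external_service_configured():
        raise AgentGeneratorNotConfiguredError()
    # A None return signals that the external service call failed.
    return await decompose_goal_external(description, context)
```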
422
autogpt_platform/backend/test/agent_generator/test_service.py
Normal file
@@ -0,0 +1,422 @@
"""
Tests for the Agent Generator external service client.

This test suite verifies the external Agent Generator service integration,
including service detection, API calls, and error handling.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from backend.api.features.chat.tools.agent_generator import service


class TestServiceConfiguration:
    """Test service configuration detection."""

    def setup_method(self):
        """Reset settings singleton before each test."""
        service._settings = None
        service._client = None

    def test_external_service_not_configured_when_host_empty(self):
        """Test that external service is not configured when host is empty."""
        mock_settings = MagicMock()
        mock_settings.config.agentgenerator_host = ""

        with patch.object(service, "_get_settings", return_value=mock_settings):
            assert service.is_external_service_configured() is False

    def test_external_service_configured_when_host_set(self):
        """Test that external service is configured when host is set."""
        mock_settings = MagicMock()
        mock_settings.config.agentgenerator_host = "agent-generator.local"

        with patch.object(service, "_get_settings", return_value=mock_settings):
            assert service.is_external_service_configured() is True

    def test_get_base_url(self):
        """Test base URL construction."""
        mock_settings = MagicMock()
        mock_settings.config.agentgenerator_host = "agent-generator.local"
        mock_settings.config.agentgenerator_port = 8000

        with patch.object(service, "_get_settings", return_value=mock_settings):
            url = service._get_base_url()
            assert url == "http://agent-generator.local:8000"


class TestDecomposeGoalExternal:
    """Test decompose_goal_external function."""

    def setup_method(self):
        """Reset client singleton before each test."""
        service._settings = None
        service._client = None

    @pytest.mark.asyncio
    async def test_decompose_goal_returns_instructions(self):
        """Test successful decomposition returning instructions."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "type": "instructions",
            "steps": ["Step 1", "Step 2"],
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.decompose_goal_external("Build a chatbot")

            assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]}
            mock_client.post.assert_called_once_with(
                "/api/decompose-description", json={"description": "Build a chatbot"}
            )

    @pytest.mark.asyncio
    async def test_decompose_goal_returns_clarifying_questions(self):
        """Test decomposition returning clarifying questions."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "type": "clarifying_questions",
            "questions": ["What platform?", "What language?"],
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.decompose_goal_external("Build something")

            assert result == {
                "type": "clarifying_questions",
                "questions": ["What platform?", "What language?"],
            }

    @pytest.mark.asyncio
    async def test_decompose_goal_with_context(self):
        """Test decomposition with additional context."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "type": "instructions",
            "steps": ["Step 1"],
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            await service.decompose_goal_external(
                "Build a chatbot", context="Use Python"
            )

            mock_client.post.assert_called_once_with(
                "/api/decompose-description",
                json={"description": "Build a chatbot", "user_instruction": "Use Python"},
            )

    @pytest.mark.asyncio
    async def test_decompose_goal_returns_unachievable_goal(self):
        """Test decomposition returning unachievable goal response."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "type": "unachievable_goal",
            "reason": "Cannot do X",
            "suggested_goal": "Try Y instead",
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.decompose_goal_external("Do something impossible")

            assert result == {
                "type": "unachievable_goal",
                "reason": "Cannot do X",
                "suggested_goal": "Try Y instead",
            }

    @pytest.mark.asyncio
    async def test_decompose_goal_handles_http_error(self):
        """Test decomposition handles HTTP errors gracefully."""
        mock_client = AsyncMock()
        mock_client.post.side_effect = httpx.HTTPStatusError(
            "Server error", request=MagicMock(), response=MagicMock()
        )

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.decompose_goal_external("Build a chatbot")

            assert result is None

    @pytest.mark.asyncio
    async def test_decompose_goal_handles_request_error(self):
        """Test decomposition handles request errors gracefully."""
        mock_client = AsyncMock()
        mock_client.post.side_effect = httpx.RequestError("Connection failed")

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.decompose_goal_external("Build a chatbot")

            assert result is None

    @pytest.mark.asyncio
    async def test_decompose_goal_handles_service_error(self):
        """Test decomposition handles service returning error."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": False,
            "error": "Internal error",
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.decompose_goal_external("Build a chatbot")

            assert result is None


class TestGenerateAgentExternal:
    """Test generate_agent_external function."""

    def setup_method(self):
        """Reset client singleton before each test."""
        service._settings = None
        service._client = None

    @pytest.mark.asyncio
    async def test_generate_agent_success(self):
        """Test successful agent generation."""
        agent_json = {
            "name": "Test Agent",
            "nodes": [],
            "links": [],
        }
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "agent_json": agent_json,
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        instructions = {"type": "instructions", "steps": ["Step 1"]}

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.generate_agent_external(instructions)

            assert result == agent_json
            mock_client.post.assert_called_once_with(
                "/api/generate-agent", json={"instructions": instructions}
            )

    @pytest.mark.asyncio
    async def test_generate_agent_handles_error(self):
        """Test agent generation handles errors gracefully."""
        mock_client = AsyncMock()
        mock_client.post.side_effect = httpx.RequestError("Connection failed")

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.generate_agent_external({"steps": []})

            assert result is None


class TestGenerateAgentPatchExternal:
    """Test generate_agent_patch_external function."""

    def setup_method(self):
        """Reset client singleton before each test."""
        service._settings = None
        service._client = None

    @pytest.mark.asyncio
    async def test_generate_patch_returns_updated_agent(self):
        """Test successful patch generation returning updated agent."""
        updated_agent = {
            "name": "Updated Agent",
            "nodes": [{"id": "1", "block_id": "test"}],
            "links": [],
        }
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "agent_json": updated_agent,
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        current_agent = {"name": "Old Agent", "nodes": [], "links": []}

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.generate_agent_patch_external(
                "Add a new node", current_agent
            )

            assert result == updated_agent
            mock_client.post.assert_called_once_with(
                "/api/update-agent",
                json={
                    "update_request": "Add a new node",
                    "current_agent_json": current_agent,
                },
            )

    @pytest.mark.asyncio
    async def test_generate_patch_returns_clarifying_questions(self):
        """Test patch generation returning clarifying questions."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "type": "clarifying_questions",
            "questions": ["What type of node?"],
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.generate_agent_patch_external(
                "Add something", {"nodes": []}
            )

            assert result == {
                "type": "clarifying_questions",
                "questions": ["What type of node?"],
            }


class TestHealthCheck:
    """Test health_check function."""

    def setup_method(self):
        """Reset singletons before each test."""
        service._settings = None
        service._client = None

    @pytest.mark.asyncio
    async def test_health_check_returns_false_when_not_configured(self):
        """Test health check returns False when service not configured."""
        with patch.object(
            service, "is_external_service_configured", return_value=False
        ):
            result = await service.health_check()
            assert result is False

    @pytest.mark.asyncio
    async def test_health_check_returns_true_when_healthy(self):
        """Test health check returns True when service is healthy."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "status": "healthy",
            "blocks_loaded": True,
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.get.return_value = mock_response

        with patch.object(service, "is_external_service_configured", return_value=True):
            with patch.object(service, "_get_client", return_value=mock_client):
                result = await service.health_check()

                assert result is True
                mock_client.get.assert_called_once_with("/health")

    @pytest.mark.asyncio
    async def test_health_check_returns_false_when_not_healthy(self):
        """Test health check returns False when service is not healthy."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "status": "unhealthy",
            "blocks_loaded": False,
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.get.return_value = mock_response

        with patch.object(service, "is_external_service_configured", return_value=True):
            with patch.object(service, "_get_client", return_value=mock_client):
                result = await service.health_check()

                assert result is False

    @pytest.mark.asyncio
    async def test_health_check_returns_false_on_error(self):
        """Test health check returns False on connection error."""
        mock_client = AsyncMock()
        mock_client.get.side_effect = httpx.RequestError("Connection failed")

        with patch.object(service, "is_external_service_configured", return_value=True):
            with patch.object(service, "_get_client", return_value=mock_client):
                result = await service.health_check()

                assert result is False


class TestGetBlocksExternal:
    """Test get_blocks_external function."""

    def setup_method(self):
        """Reset client singleton before each test."""
        service._settings = None
        service._client = None

    @pytest.mark.asyncio
    async def test_get_blocks_success(self):
        """Test successful blocks retrieval."""
        blocks = [
            {"id": "block1", "name": "Block 1"},
            {"id": "block2", "name": "Block 2"},
        ]
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "blocks": blocks,
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.get.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.get_blocks_external()

            assert result == blocks
            mock_client.get.assert_called_once_with("/api/blocks")

    @pytest.mark.asyncio
    async def test_get_blocks_handles_error(self):
        """Test blocks retrieval handles errors gracefully."""
        mock_client = AsyncMock()
        mock_client.get.side_effect = httpx.RequestError("Connection failed")

        with patch.object(service, "_get_client", return_value=mock_client):
            result = await service.get_blocks_external()

            assert result is None


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
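The repeated `service._settings = None; service._client = None` resets suggest the module caches both behind lazy singletons. A hedged sketch of what `_get_client` likely looks like; wiring `agentgenerator_timeout` into the client is an assumption:

```python
import httpx

_client: httpx.AsyncClient | None = None


def _get_client() -> httpx.AsyncClient:
    # Lazily create one AsyncClient against the configured service;
    # the tests above reset this singleton in setup_method.
    global _client
    if _client is None:
        settings = _get_settings()
        _client = httpx.AsyncClient(
            base_url=_get_base_url(),
            timeout=settings.config.agentgenerator_timeout,  # assumption
        )
    return _client
```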
@@ -86,7 +86,6 @@ export function FloatingSafeModeToggle({
   const {
     currentHITLSafeMode,
     showHITLToggle,
-    isHITLStateUndetermined,
     handleHITLToggle,
     currentSensitiveActionSafeMode,
     showSensitiveActionToggle,
@@ -99,16 +98,9 @@ export function FloatingSafeModeToggle({
     return null;
   }
 
-  const showHITL = showHITLToggle && !isHITLStateUndetermined;
-  const showSensitive = showSensitiveActionToggle;
-
-  if (!showHITL && !showSensitive) {
-    return null;
-  }
-
   return (
     <div className={cn("fixed z-50 flex flex-col gap-2", className)}>
-      {showHITL && (
+      {showHITLToggle && (
         <SafeModeButton
           isEnabled={currentHITLSafeMode}
           label="Human in the loop block approval"
@@ -119,7 +111,7 @@ export function FloatingSafeModeToggle({
           fullWidth={fullWidth}
         />
       )}
-      {showSensitive && (
+      {showSensitiveActionToggle && (
         <SafeModeButton
           isEnabled={currentSensitiveActionSafeMode}
           label="Sensitive actions blocks approval"
@@ -0,0 +1,41 @@
+"use client";
+
+import { createContext, useContext, useRef, type ReactNode } from "react";
+
+interface NewChatContextValue {
+  onNewChatClick: () => void;
+  setOnNewChatClick: (handler?: () => void) => void;
+  performNewChat?: () => void;
+  setPerformNewChat: (handler?: () => void) => void;
+}
+
+const NewChatContext = createContext<NewChatContextValue | null>(null);
+
+export function NewChatProvider({ children }: { children: ReactNode }) {
+  const onNewChatRef = useRef<(() => void) | undefined>();
+  const performNewChatRef = useRef<(() => void) | undefined>();
+  const contextValueRef = useRef<NewChatContextValue>({
+    onNewChatClick() {
+      onNewChatRef.current?.();
+    },
+    setOnNewChatClick(handler?: () => void) {
+      onNewChatRef.current = handler;
+    },
+    performNewChat() {
+      performNewChatRef.current?.();
+    },
+    setPerformNewChat(handler?: () => void) {
+      performNewChatRef.current = handler;
+    },
+  });
+
+  return (
+    <NewChatContext.Provider value={contextValueRef.current}>
+      {children}
+    </NewChatContext.Provider>
+  );
+}
+
+export function useNewChat() {
+  return useContext(NewChatContext);
+}

@@ -1,8 +1,10 @@
 "use client";
 
-import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
+import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
 import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
 import type { ReactNode } from "react";
+import { useEffect } from "react";
+import { useNewChat } from "../../NewChatContext";
 import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
 import { LoadingState } from "./components/LoadingState/LoadingState";
 import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
@@ -33,10 +35,25 @@ export function CopilotShell({ children }: Props) {
     isReadyToShowContent,
   } = useCopilotShell();
 
+  const newChatContext = useNewChat();
+  const handleNewChatClickWrapper =
+    newChatContext?.onNewChatClick || handleNewChat;
+
+  useEffect(
+    function registerNewChatHandler() {
+      if (!newChatContext) return;
+      newChatContext.setPerformNewChat(handleNewChat);
+      return function cleanup() {
+        newChatContext.setPerformNewChat(undefined);
+      };
+    },
+    [newChatContext, handleNewChat],
+  );
+
   if (!isLoggedIn) {
     return (
       <div className="flex h-full items-center justify-center">
-        <LoadingSpinner size="large" />
+        <ChatLoader />
       </div>
     );
   }
@@ -55,7 +72,7 @@ export function CopilotShell({ children }: Props) {
           isFetchingNextPage={isFetchingNextPage}
           onSelectSession={handleSelectSession}
           onFetchNextPage={fetchNextPage}
-          onNewChat={handleNewChat}
+          onNewChat={handleNewChatClickWrapper}
           hasActiveSession={Boolean(hasActiveSession)}
         />
       )}
@@ -77,7 +94,7 @@ export function CopilotShell({ children }: Props) {
           isFetchingNextPage={isFetchingNextPage}
           onSelectSession={handleSelectSession}
           onFetchNextPage={fetchNextPage}
-          onNewChat={handleNewChat}
+          onNewChat={handleNewChatClickWrapper}
           onClose={handleCloseDrawer}
           onOpenChange={handleDrawerOpenChange}
           hasActiveSession={Boolean(hasActiveSession)}

@@ -148,13 +148,15 @@ export function useCopilotShell() {
     setHasAutoSelectedSession(false);
   }
 
+  const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
+
   return {
     isMobile,
     isDrawerOpen,
     isLoggedIn,
     hasActiveSession:
       Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
-    isLoading: isSessionsLoading || !areAllSessionsLoaded,
+    isLoading,
     sessions: visibleSessions,
     currentSessionId: sidebarSelectedSessionId,
     handleSelectSession,

@@ -1,5 +1,28 @@
 import type { User } from "@supabase/supabase-js";
 
+export type PageState =
+  | { type: "welcome" }
+  | { type: "newChat" }
+  | { type: "creating"; prompt: string }
+  | { type: "chat"; sessionId: string; initialPrompt?: string };
+
+export function getInitialPromptFromState(
+  pageState: PageState,
+  storedInitialPrompt: string | undefined,
+) {
+  if (storedInitialPrompt) return storedInitialPrompt;
+  if (pageState.type === "creating") return pageState.prompt;
+  if (pageState.type === "chat") return pageState.initialPrompt;
+}
+
+export function shouldResetToWelcome(pageState: PageState) {
+  return (
+    pageState.type !== "newChat" &&
+    pageState.type !== "creating" &&
+    pageState.type !== "welcome"
+  );
+}
+
 export function getGreetingName(user?: User | null): string {
   if (!user) return "there";
   const metadata = user.user_metadata as Record<string, unknown> | undefined;

@@ -1,6 +1,11 @@
 import type { ReactNode } from "react";
+import { NewChatProvider } from "./NewChatContext";
 import { CopilotShell } from "./components/CopilotShell/CopilotShell";
 
 export default function CopilotLayout({ children }: { children: ReactNode }) {
-  return <CopilotShell>{children}</CopilotShell>;
+  return (
+    <NewChatProvider>
+      <CopilotShell>{children}</CopilotShell>
+    </NewChatProvider>
+  );
 }

@@ -1,142 +1,35 @@
 "use client";
 
-import { postV2CreateSession } from "@/app/api/__generated__/endpoints/chat/chat";
 import { Skeleton } from "@/components/__legacy__/ui/skeleton";
 import { Button } from "@/components/atoms/Button/Button";
-import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
 import { Text } from "@/components/atoms/Text/Text";
 import { Chat } from "@/components/contextual/Chat/Chat";
 import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
-import { getHomepageRoute } from "@/lib/constants";
-import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
-import {
-  Flag,
-  type FlagValues,
-  useGetFlag,
-} from "@/services/feature-flags/use-get-flag";
-import { useFlags } from "launchdarkly-react-client-sdk";
-import { useRouter, useSearchParams } from "next/navigation";
-import { useEffect, useMemo, useRef, useState } from "react";
-import { getGreetingName, getQuickActions } from "./helpers";
-
-type PageState =
-  | { type: "welcome" }
-  | { type: "creating"; prompt: string }
-  | { type: "chat"; sessionId: string; initialPrompt?: string };
+import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
+import { Dialog } from "@/components/molecules/Dialog/Dialog";
+import { useCopilotPage } from "./useCopilotPage";
 
 export default function CopilotPage() {
-  const router = useRouter();
-  const searchParams = useSearchParams();
-  const { user, isLoggedIn, isUserLoading } = useSupabase();
-
-  const isChatEnabled = useGetFlag(Flag.CHAT);
-  const flags = useFlags<FlagValues>();
-  const homepageRoute = getHomepageRoute(isChatEnabled);
-  const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
-  const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
-  const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
-  const isFlagReady =
-    !isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
-
-  const [pageState, setPageState] = useState<PageState>({ type: "welcome" });
-  const initialPromptRef = useRef<Map<string, string>>(new Map());
-
-  const urlSessionId = searchParams.get("sessionId");
-
-  // Sync with URL sessionId (preserve initialPrompt from ref)
-  useEffect(
-    function syncSessionFromUrl() {
-      if (urlSessionId) {
-        // If we're already in chat state with this sessionId, don't overwrite
-        if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
-          return;
-        }
-        // Get initialPrompt from ref or current state
-        const storedInitialPrompt = initialPromptRef.current.get(urlSessionId);
-        const currentInitialPrompt =
-          storedInitialPrompt ||
-          (pageState.type === "creating"
-            ? pageState.prompt
-            : pageState.type === "chat"
-              ? pageState.initialPrompt
-              : undefined);
-        if (currentInitialPrompt) {
-          initialPromptRef.current.set(urlSessionId, currentInitialPrompt);
-        }
-        setPageState({
-          type: "chat",
-          sessionId: urlSessionId,
-          initialPrompt: currentInitialPrompt,
-        });
-      } else if (pageState.type === "chat") {
-        setPageState({ type: "welcome" });
-      }
-    },
-    [urlSessionId],
-  );
-
-  useEffect(
-    function ensureAccess() {
-      if (!isFlagReady) return;
-      if (isChatEnabled === false) {
-        router.replace(homepageRoute);
-      }
-    },
-    [homepageRoute, isChatEnabled, isFlagReady, router],
-  );
-
-  const greetingName = useMemo(
-    function getName() {
-      return getGreetingName(user);
-    },
-    [user],
-  );
-
-  const quickActions = useMemo(function getActions() {
-    return getQuickActions();
-  }, []);
-
-  async function startChatWithPrompt(prompt: string) {
-    if (!prompt?.trim()) return;
-    if (pageState.type === "creating") return;
-
-    const trimmedPrompt = prompt.trim();
-    setPageState({ type: "creating", prompt: trimmedPrompt });
-
-    try {
-      // Create session
-      const sessionResponse = await postV2CreateSession({
-        body: JSON.stringify({}),
-      });
-
-      if (sessionResponse.status !== 200 || !sessionResponse.data?.id) {
-        throw new Error("Failed to create session");
-      }
-
-      const sessionId = sessionResponse.data.id;
-
-      // Store initialPrompt in ref so it persists across re-renders
-      initialPromptRef.current.set(sessionId, trimmedPrompt);
-
-      // Update URL and show Chat with initial prompt
-      // Chat will handle sending the message and streaming
-      window.history.replaceState(null, "", `/copilot?sessionId=${sessionId}`);
-      setPageState({ type: "chat", sessionId, initialPrompt: trimmedPrompt });
-    } catch (error) {
-      console.error("[CopilotPage] Failed to start chat:", error);
-      setPageState({ type: "welcome" });
-    }
-  }
-
-  function handleQuickAction(action: string) {
-    startChatWithPrompt(action);
-  }
-
-  function handleSessionNotFound() {
-    router.replace("/copilot");
-  }
-
-  if (!isFlagReady || isChatEnabled === false || !isLoggedIn) {
+  const { state, handlers } = useCopilotPage();
+  const {
+    greetingName,
+    quickActions,
+    isLoading,
+    pageState,
+    isNewChatModalOpen,
+    isReady,
+  } = state;
+  const {
+    handleQuickAction,
+    startChatWithPrompt,
+    handleSessionNotFound,
+    handleStreamingChange,
+    handleCancelNewChat,
+    proceedWithNewChat,
+    handleNewChatModalOpen,
+  } = handlers;
+
+  if (!isReady) {
     return null;
   }
+
@@ -150,7 +43,55 @@ export default function CopilotPage() {
           urlSessionId={pageState.sessionId}
           initialPrompt={pageState.initialPrompt}
           onSessionNotFound={handleSessionNotFound}
+          onStreamingChange={handleStreamingChange}
         />
+        <Dialog
+          title="Interrupt current chat?"
+          styling={{ maxWidth: 300, width: "100%" }}
+          controlled={{
+            isOpen: isNewChatModalOpen,
+            set: handleNewChatModalOpen,
+          }}
+          onClose={handleCancelNewChat}
+        >
+          <Dialog.Content>
+            <div className="flex flex-col gap-4">
+              <Text variant="body">
+                The current chat response will be interrupted. Are you sure you
+                want to start a new chat?
+              </Text>
+              <Dialog.Footer>
+                <Button
+                  type="button"
+                  variant="outline"
+                  onClick={handleCancelNewChat}
+                >
+                  Cancel
+                </Button>
+                <Button
+                  type="button"
+                  variant="primary"
+                  onClick={proceedWithNewChat}
+                >
+                  Start new chat
+                </Button>
+              </Dialog.Footer>
+            </div>
+          </Dialog.Content>
+        </Dialog>
+      </div>
+    );
+  }
+
+  if (pageState.type === "newChat") {
+    return (
+      <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
+        <div className="flex flex-col items-center gap-4">
+          <ChatLoader />
+          <Text variant="body" className="text-zinc-500">
+            Loading your chats...
+          </Text>
+        </div>
       </div>
     );
   }
@@ -158,18 +99,18 @@ export default function CopilotPage() {
   // Show loading state while creating session and sending first message
   if (pageState.type === "creating") {
     return (
-      <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9] px-6 py-10">
-        <LoadingSpinner size="large" />
-        <Text variant="body" className="mt-4 text-zinc-500">
-          Starting your chat...
-        </Text>
+      <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
+        <div className="flex flex-col items-center gap-4">
+          <ChatLoader />
+          <Text variant="body" className="text-zinc-500">
+            Loading your chats...
+          </Text>
+        </div>
       </div>
     );
   }
 
   // Show Welcome screen
-  const isLoading = isUserLoading;
-
   return (
     <div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
       <div className="w-full text-center">

@@ -0,0 +1,266 @@
+import { postV2CreateSession } from "@/app/api/__generated__/endpoints/chat/chat";
+import { useToast } from "@/components/molecules/Toast/use-toast";
+import { getHomepageRoute } from "@/lib/constants";
+import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
+import {
+  Flag,
+  type FlagValues,
+  useGetFlag,
+} from "@/services/feature-flags/use-get-flag";
+import * as Sentry from "@sentry/nextjs";
+import { useFlags } from "launchdarkly-react-client-sdk";
+import { useRouter } from "next/navigation";
+import { useEffect, useReducer } from "react";
+import { useNewChat } from "./NewChatContext";
+import { getGreetingName, getQuickActions, type PageState } from "./helpers";
+import { useCopilotURLState } from "./useCopilotURLState";
+
+type CopilotState = {
+  pageState: PageState;
+  isStreaming: boolean;
+  isNewChatModalOpen: boolean;
+  initialPrompts: Record<string, string>;
+  previousSessionId: string | null;
+};
+
+type CopilotAction =
+  | { type: "setPageState"; pageState: PageState }
+  | { type: "setStreaming"; isStreaming: boolean }
+  | { type: "setNewChatModalOpen"; isOpen: boolean }
+  | { type: "setInitialPrompt"; sessionId: string; prompt: string }
+  | { type: "setPreviousSessionId"; sessionId: string | null };
+
+function isSamePageState(next: PageState, current: PageState) {
+  if (next.type !== current.type) return false;
+  if (next.type === "creating" && current.type === "creating") {
+    return next.prompt === current.prompt;
+  }
+  if (next.type === "chat" && current.type === "chat") {
+    return (
+      next.sessionId === current.sessionId &&
+      next.initialPrompt === current.initialPrompt
+    );
+  }
+  return true;
+}
+
+function copilotReducer(
+  state: CopilotState,
+  action: CopilotAction,
+): CopilotState {
+  if (action.type === "setPageState") {
+    if (isSamePageState(action.pageState, state.pageState)) return state;
+    return { ...state, pageState: action.pageState };
+  }
+  if (action.type === "setStreaming") {
+    if (action.isStreaming === state.isStreaming) return state;
+    return { ...state, isStreaming: action.isStreaming };
+  }
+  if (action.type === "setNewChatModalOpen") {
+    if (action.isOpen === state.isNewChatModalOpen) return state;
+    return { ...state, isNewChatModalOpen: action.isOpen };
+  }
+  if (action.type === "setInitialPrompt") {
+    if (state.initialPrompts[action.sessionId] === action.prompt) return state;
+    return {
+      ...state,
+      initialPrompts: {
+        ...state.initialPrompts,
+        [action.sessionId]: action.prompt,
+      },
+    };
+  }
+  if (action.type === "setPreviousSessionId") {
+    if (state.previousSessionId === action.sessionId) return state;
+    return { ...state, previousSessionId: action.sessionId };
+  }
+  return state;
+}
+
+export function useCopilotPage() {
+  const router = useRouter();
+  const { user, isLoggedIn, isUserLoading } = useSupabase();
+  const { toast } = useToast();
+
+  const isChatEnabled = useGetFlag(Flag.CHAT);
+  const flags = useFlags<FlagValues>();
+  const homepageRoute = getHomepageRoute(isChatEnabled);
+  const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
+  const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
+  const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
+  const isFlagReady =
+    !isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
+
+  const [state, dispatch] = useReducer(copilotReducer, {
+    pageState: { type: "welcome" },
+    isStreaming: false,
+    isNewChatModalOpen: false,
+    initialPrompts: {},
+    previousSessionId: null,
+  });
+
+  const newChatContext = useNewChat();
+  const greetingName = getGreetingName(user);
+  const quickActions = getQuickActions();
+
+  function setPageState(pageState: PageState) {
+    dispatch({ type: "setPageState", pageState });
+  }
+
+  function setInitialPrompt(sessionId: string, prompt: string) {
+    dispatch({ type: "setInitialPrompt", sessionId, prompt });
+  }
+
+  function setPreviousSessionId(sessionId: string | null) {
+    dispatch({ type: "setPreviousSessionId", sessionId });
+  }
+
+  const { setUrlSessionId } = useCopilotURLState({
+    pageState: state.pageState,
+    initialPrompts: state.initialPrompts,
+    previousSessionId: state.previousSessionId,
+    setPageState,
+    setInitialPrompt,
+    setPreviousSessionId,
+  });
+
+  useEffect(
+    function registerNewChatHandler() {
+      if (!newChatContext) return;
+      newChatContext.setOnNewChatClick(handleNewChatClick);
+      return function cleanup() {
+        newChatContext.setOnNewChatClick(undefined);
+      };
+    },
+    [newChatContext, handleNewChatClick],
+  );
+
+  useEffect(
+    function transitionNewChatToWelcome() {
+      if (state.pageState.type === "newChat") {
+        function setWelcomeState() {
+          dispatch({ type: "setPageState", pageState: { type: "welcome" } });
+        }
+
+        const timer = setTimeout(setWelcomeState, 300);
+
+        return function cleanup() {
+          clearTimeout(timer);
+        };
+      }
+    },
+    [state.pageState.type],
+  );
+
+  useEffect(
+    function ensureAccess() {
+      if (!isFlagReady) return;
+      if (isChatEnabled === false) {
+        router.replace(homepageRoute);
+      }
+    },
+    [homepageRoute, isChatEnabled, isFlagReady, router],
+  );
+
+  async function startChatWithPrompt(prompt: string) {
+    if (!prompt?.trim()) return;
+    if (state.pageState.type === "creating") return;
+
+    const trimmedPrompt = prompt.trim();
+    dispatch({
+      type: "setPageState",
+      pageState: { type: "creating", prompt: trimmedPrompt },
+    });
+
+    try {
+      const sessionResponse = await postV2CreateSession({
+        body: JSON.stringify({}),
+      });
+
+      if (sessionResponse.status !== 200 || !sessionResponse.data?.id) {
+        throw new Error("Failed to create session");
+      }
+
+      const sessionId = sessionResponse.data.id;
+
+      dispatch({
+        type: "setInitialPrompt",
+        sessionId,
+        prompt: trimmedPrompt,
+      });
+
+      await setUrlSessionId(sessionId, { shallow: false });
+      dispatch({
+        type: "setPageState",
+        pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt },
+      });
+    } catch (error) {
+      console.error("[CopilotPage] Failed to start chat:", error);
+      toast({ title: "Failed to start chat", variant: "destructive" });
+      Sentry.captureException(error);
+      dispatch({ type: "setPageState", pageState: { type: "welcome" } });
+    }
+  }
+
+  function handleQuickAction(action: string) {
+    startChatWithPrompt(action);
+  }
+
+  function handleSessionNotFound() {
+    router.replace("/copilot");
+  }
+
+  function handleStreamingChange(isStreamingValue: boolean) {
+    dispatch({ type: "setStreaming", isStreaming: isStreamingValue });
+  }
+
+  async function proceedWithNewChat() {
+    dispatch({ type: "setNewChatModalOpen", isOpen: false });
+    if (newChatContext?.performNewChat) {
+      newChatContext.performNewChat();
+      return;
+    }
+    try {
+      await setUrlSessionId(null, { shallow: false });
+    } catch (error) {
+      console.error("[CopilotPage] Failed to clear session:", error);
+    }
+    router.replace("/copilot");
+  }
+
+  function handleCancelNewChat() {
+    dispatch({ type: "setNewChatModalOpen", isOpen: false });
+  }
+
+  function handleNewChatModalOpen(isOpen: boolean) {
+    dispatch({ type: "setNewChatModalOpen", isOpen });
+  }
+
+  function handleNewChatClick() {
+    if (state.isStreaming) {
+      dispatch({ type: "setNewChatModalOpen", isOpen: true });
+    } else {
+      proceedWithNewChat();
+    }
+  }
+
+  return {
+    state: {
+      greetingName,
+      quickActions,
+      isLoading: isUserLoading,
+      pageState: state.pageState,
+      isNewChatModalOpen: state.isNewChatModalOpen,
+      isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
+    },
+    handlers: {
+      handleQuickAction,
+      startChatWithPrompt,
+      handleSessionNotFound,
+      handleStreamingChange,
+      handleCancelNewChat,
+      proceedWithNewChat,
+      handleNewChatModalOpen,
+    },
+  };
+}

@@ -0,0 +1,80 @@
+import { parseAsString, useQueryState } from "nuqs";
+import { useLayoutEffect } from "react";
+import {
+  getInitialPromptFromState,
+  type PageState,
+  shouldResetToWelcome,
+} from "./helpers";
+
+interface UseCopilotUrlStateArgs {
+  pageState: PageState;
+  initialPrompts: Record<string, string>;
+  previousSessionId: string | null;
+  setPageState: (pageState: PageState) => void;
+  setInitialPrompt: (sessionId: string, prompt: string) => void;
+  setPreviousSessionId: (sessionId: string | null) => void;
+}
+
+export function useCopilotURLState({
+  pageState,
+  initialPrompts,
+  previousSessionId,
+  setPageState,
+  setInitialPrompt,
+  setPreviousSessionId,
+}: UseCopilotUrlStateArgs) {
+  const [urlSessionId, setUrlSessionId] = useQueryState(
+    "sessionId",
+    parseAsString,
+  );
+
+  function syncSessionFromUrl() {
+    if (urlSessionId) {
+      if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
+        setPreviousSessionId(urlSessionId);
+        return;
+      }
+
+      const storedInitialPrompt = initialPrompts[urlSessionId];
+      const currentInitialPrompt = getInitialPromptFromState(
+        pageState,
+        storedInitialPrompt,
+      );
+
+      if (currentInitialPrompt) {
+        setInitialPrompt(urlSessionId, currentInitialPrompt);
+      }
+
+      setPageState({
+        type: "chat",
+        sessionId: urlSessionId,
+        initialPrompt: currentInitialPrompt,
+      });
+      setPreviousSessionId(urlSessionId);
+      return;
+    }
+
+    const wasInChat = previousSessionId !== null && pageState.type === "chat";
+    setPreviousSessionId(null);
+    if (wasInChat) {
+      setPageState({ type: "newChat" });
+      return;
+    }
+
+    if (shouldResetToWelcome(pageState)) {
+      setPageState({ type: "welcome" });
+    }
+  }
+
+  useLayoutEffect(syncSessionFromUrl, [
+    urlSessionId,
+    pageState.type,
+    previousSessionId,
+    initialPrompts,
+  ]);
+
+  return {
+    urlSessionId,
+    setUrlSessionId,
+  };
+}

@@ -14,6 +14,10 @@ import {
 import { Dialog } from "@/components/molecules/Dialog/Dialog";
 import { useEffect, useRef, useState } from "react";
 import { ScheduleAgentModal } from "../ScheduleAgentModal/ScheduleAgentModal";
+import {
+  AIAgentSafetyPopup,
+  useAIAgentSafetyPopup,
+} from "./components/AIAgentSafetyPopup/AIAgentSafetyPopup";
 import { ModalHeader } from "./components/ModalHeader/ModalHeader";
 import { ModalRunSection } from "./components/ModalRunSection/ModalRunSection";
 import { RunActions } from "./components/RunActions/RunActions";
@@ -83,8 +87,18 @@ export function RunAgentModal({
 
   const [isScheduleModalOpen, setIsScheduleModalOpen] = useState(false);
   const [hasOverflow, setHasOverflow] = useState(false);
+  const [isSafetyPopupOpen, setIsSafetyPopupOpen] = useState(false);
+  const [pendingRunAction, setPendingRunAction] = useState<(() => void) | null>(
+    null,
+  );
   const contentRef = useRef<HTMLDivElement>(null);
+
+  const { shouldShowPopup, dismissPopup } = useAIAgentSafetyPopup(
+    agent.id,
+    agent.has_sensitive_action,
+    agent.has_human_in_the_loop,
+  );
 
   const hasAnySetupFields =
     Object.keys(agentInputFields || {}).length > 0 ||
     Object.keys(agentCredentialsInputFields || {}).length > 0;
@@ -165,6 +179,24 @@ export function RunAgentModal({
     onScheduleCreated?.(schedule);
   }
 
+  function handleRunWithSafetyCheck() {
+    if (shouldShowPopup) {
+      setPendingRunAction(() => handleRun);
+      setIsSafetyPopupOpen(true);
+    } else {
+      handleRun();
+    }
+  }
+
+  function handleSafetyPopupAcknowledge() {
+    setIsSafetyPopupOpen(false);
+    dismissPopup();
+    if (pendingRunAction) {
+      pendingRunAction();
+      setPendingRunAction(null);
+    }
+  }
+
   return (
     <>
       <Dialog
@@ -248,7 +280,7 @@ export function RunAgentModal({
         )}
         <RunActions
           defaultRunType={defaultRunType}
-          onRun={handleRun}
+          onRun={handleRunWithSafetyCheck}
           isExecuting={isExecuting}
           isSettingUpTrigger={isSettingUpTrigger}
           isRunReady={allRequiredInputsAreSet}
@@ -266,6 +298,12 @@ export function RunAgentModal({
         </div>
       </Dialog.Content>
     </Dialog>
+
+      <AIAgentSafetyPopup
+        agentId={agent.id}
+        isOpen={isSafetyPopupOpen}
+        onAcknowledge={handleSafetyPopupAcknowledge}
+      />
     </>
   );
 }

@@ -0,0 +1,108 @@
+"use client";
+
+import { Button } from "@/components/atoms/Button/Button";
+import { Text } from "@/components/atoms/Text/Text";
+import { Dialog } from "@/components/molecules/Dialog/Dialog";
+import { Key, storage } from "@/services/storage/local-storage";
+import { ShieldCheckIcon } from "@phosphor-icons/react";
+import { useCallback, useEffect, useState } from "react";
+
+interface Props {
+  agentId: string;
+  onAcknowledge: () => void;
+  isOpen: boolean;
+}
+
+export function AIAgentSafetyPopup({ agentId, onAcknowledge, isOpen }: Props) {
+  function handleAcknowledge() {
+    // Add this agent to the list of agents for which popup has been shown
+    const seenAgentsJson = storage.get(Key.AI_AGENT_SAFETY_POPUP_SHOWN);
+    const seenAgents: string[] = seenAgentsJson
+      ? JSON.parse(seenAgentsJson)
+      : [];
+
+    if (!seenAgents.includes(agentId)) {
+      seenAgents.push(agentId);
+      storage.set(Key.AI_AGENT_SAFETY_POPUP_SHOWN, JSON.stringify(seenAgents));
+    }
+
+    onAcknowledge();
+  }
+
+  if (!isOpen) return null;
+
+  return (
+    <Dialog
+      controlled={{ isOpen, set: () => {} }}
+      styling={{ maxWidth: "480px" }}
+    >
+      <Dialog.Content>
+        <div className="flex flex-col items-center p-6 text-center">
+          <div className="mb-6 flex h-16 w-16 items-center justify-center rounded-full bg-blue-50">
+            <ShieldCheckIcon
+              weight="fill"
+              size={32}
+              className="text-blue-600"
+            />
+          </div>
+
+          <Text variant="h3" className="mb-4">
+            Safety Checks Enabled
+          </Text>
+
+          <Text variant="body" className="mb-2 text-zinc-700">
+            AI-generated agents may take actions that affect your data or
+            external systems.
+          </Text>
+
+          <Text variant="body" className="mb-8 text-zinc-700">
+            AutoGPT includes safety checks so you'll always have the
+            opportunity to review and approve sensitive actions before they
+            happen.
+          </Text>
+
+          <Button
+            variant="primary"
+            size="large"
+            className="w-full"
+            onClick={handleAcknowledge}
+          >
+            Got it
+          </Button>
+        </div>
+      </Dialog.Content>
+    </Dialog>
+  );
+}
+
+export function useAIAgentSafetyPopup(
+  agentId: string,
+  hasSensitiveAction: boolean,
+  hasHumanInTheLoop: boolean,
+) {
+  const [shouldShowPopup, setShouldShowPopup] = useState(false);
+  const [hasChecked, setHasChecked] = useState(false);
+
+  useEffect(() => {
+    if (hasChecked) return;
+
+    const seenAgentsJson = storage.get(Key.AI_AGENT_SAFETY_POPUP_SHOWN);
+    const seenAgents: string[] = seenAgentsJson
+      ? JSON.parse(seenAgentsJson)
+      : [];
+    const hasSeenPopupForThisAgent = seenAgents.includes(agentId);
+    const isRelevantAgent = hasSensitiveAction || hasHumanInTheLoop;
+
+    setShouldShowPopup(!hasSeenPopupForThisAgent && isRelevantAgent);
+    setHasChecked(true);
+  }, [agentId, hasSensitiveAction, hasHumanInTheLoop, hasChecked]);
+
+  const dismissPopup = useCallback(() => {
+    setShouldShowPopup(false);
+  }, []);
+
+  return {
+    shouldShowPopup,
+    dismissPopup,
+  };
+}

@@ -69,7 +69,6 @@ export function SafeModeToggle({ graph, className }: Props) {
   const {
     currentHITLSafeMode,
     showHITLToggle,
-    isHITLStateUndetermined,
     handleHITLToggle,
     currentSensitiveActionSafeMode,
     showSensitiveActionToggle,
@@ -78,20 +77,13 @@ export function SafeModeToggle({ graph, className }: Props) {
     shouldShowToggle,
   } = useAgentSafeMode(graph);
 
-  if (!shouldShowToggle || isHITLStateUndetermined) {
-    return null;
-  }
-
-  const showHITL = showHITLToggle && !isHITLStateUndetermined;
-  const showSensitive = showSensitiveActionToggle;
-
-  if (!showHITL && !showSensitive) {
+  if (!shouldShowToggle) {
     return null;
   }
 
   return (
     <div className={cn("flex gap-1", className)}>
-      {showHITL && (
+      {showHITLToggle && (
         <SafeModeIconButton
           isEnabled={currentHITLSafeMode}
           label="Human-in-the-loop"
@@ -101,7 +93,7 @@ export function SafeModeToggle({ graph, className }: Props) {
           isPending={isPending}
         />
       )}
-      {showSensitive && (
+      {showSensitiveActionToggle && (
         <SafeModeIconButton
           isEnabled={currentSensitiveActionSafeMode}
           label="Sensitive actions"

@@ -0,0 +1,15 @@
+import { expect, test } from "vitest";
+import { render, screen } from "@/tests/integrations/test-utils";
+import { MainMarkeplacePage } from "../MainMarketplacePage";
+import { server } from "@/mocks/mock-server";
+import { getDeleteV2DeleteStoreSubmissionMockHandler422 } from "@/app/api/__generated__/endpoints/store/store.msw";
+
+// Only for CI testing purpose, will remove it in future PR
+test("MainMarketplacePage", async () => {
+  server.use(getDeleteV2DeleteStoreSubmissionMockHandler422());
+
+  render(<MainMarkeplacePage />);
+  expect(
+    await screen.findByText("Featured agents", { exact: false }),
+  ).toBeDefined();
+});

@@ -8809,6 +8809,12 @@
           "title": "Node Exec Id",
           "description": "Node execution ID (primary key)"
         },
+        "node_id": {
+          "type": "string",
+          "title": "Node Id",
+          "description": "Node definition ID (for grouping)",
+          "default": ""
+        },
         "user_id": {
           "type": "string",
           "title": "User Id",
@@ -8908,7 +8914,7 @@
         "created_at"
       ],
       "title": "PendingHumanReviewModel",
-      "description": "Response model for pending human review data.\n\nRepresents a human review request that is awaiting user action.\nContains all necessary information for a user to review and approve\nor reject data from a Human-in-the-Loop block execution.\n\nAttributes:\n    id: Unique identifier for the review record\n    user_id: ID of the user who must perform the review\n    node_exec_id: ID of the node execution that created this review\n    graph_exec_id: ID of the graph execution containing the node\n    graph_id: ID of the graph template being executed\n    graph_version: Version number of the graph template\n    payload: The actual data payload awaiting review\n    instructions: Instructions or message for the reviewer\n    editable: Whether the reviewer can edit the data\n    status: Current review status (WAITING, APPROVED, or REJECTED)\n    review_message: Optional message from the reviewer\n    created_at: Timestamp when review was created\n    updated_at: Timestamp when review was last modified\n    reviewed_at: Timestamp when review was completed (if applicable)"
+      "description": "Response model for pending human review data.\n\nRepresents a human review request that is awaiting user action.\nContains all necessary information for a user to review and approve\nor reject data from a Human-in-the-Loop block execution.\n\nAttributes:\n    id: Unique identifier for the review record\n    user_id: ID of the user who must perform the review\n    node_exec_id: ID of the node execution that created this review\n    node_id: ID of the node definition (for grouping reviews from same node)\n    graph_exec_id: ID of the graph execution containing the node\n    graph_id: ID of the graph template being executed\n    graph_version: Version number of the graph template\n    payload: The actual data payload awaiting review\n    instructions: Instructions or message for the reviewer\n    editable: Whether the reviewer can edit the data\n    status: Current review status (WAITING, APPROVED, or REJECTED)\n    review_message: Optional message from the reviewer\n    created_at: Timestamp when review was created\n    updated_at: Timestamp when review was last modified\n    reviewed_at: Timestamp when review was completed (if applicable)"
     },
     "PostmarkBounceEnum": {
       "type": "integer",
@@ -9411,6 +9417,12 @@
         ],
         "title": "Reviewed Data",
         "description": "Optional edited data (ignored if approved=False)"
+      },
+      "auto_approve_future": {
+        "type": "boolean",
+        "title": "Auto Approve Future",
+        "description": "If true and this review is approved, future executions of this same block (node) will be automatically approved. This only affects approved reviews.",
+        "default": false
       }
     },
     "type": "object",
@@ -9430,7 +9442,7 @@
       "type": "object",
       "required": ["reviews"],
       "title": "ReviewRequest",
-      "description": "Request model for processing ALL pending reviews for an execution.\n\nThis request must include ALL pending reviews for a graph execution.\nEach review will be either approved (with optional data modifications)\nor rejected (data ignored). The execution will resume only after ALL reviews are processed."
+      "description": "Request model for processing ALL pending reviews for an execution.\n\nThis request must include ALL pending reviews for a graph execution.\nEach review will be either approved (with optional data modifications)\nor rejected (data ignored). The execution will resume only after ALL reviews are processed.\n\nEach review item can individually specify whether to auto-approve future executions\nof the same block via the `auto_approve_future` field on ReviewItem."
     },
     "ReviewResponse": {
       "properties": {

@@ -13,6 +13,7 @@ export interface ChatProps {
   urlSessionId?: string | null;
   initialPrompt?: string;
   onSessionNotFound?: () => void;
+  onStreamingChange?: (isStreaming: boolean) => void;
 }
 
 export function Chat({
@@ -20,6 +21,7 @@ export function Chat({
   urlSessionId,
   initialPrompt,
   onSessionNotFound,
+  onStreamingChange,
 }: ChatProps) {
   const hasHandledNotFoundRef = useRef(false);
   const {
@@ -73,6 +75,7 @@ export function Chat({
           initialMessages={messages}
           initialPrompt={initialPrompt}
           className="flex-1"
+          onStreamingChange={onStreamingChange}
         />
       )}
     </main>

@@ -4,6 +4,7 @@ import { Text } from "@/components/atoms/Text/Text";
 import { Dialog } from "@/components/molecules/Dialog/Dialog";
 import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
 import { cn } from "@/lib/utils";
+import { useEffect } from "react";
 import { ChatInput } from "../ChatInput/ChatInput";
 import { MessageList } from "../MessageList/MessageList";
 import { useChatContainer } from "./useChatContainer";
@@ -13,6 +14,7 @@ export interface ChatContainerProps {
   initialMessages: SessionDetailResponse["messages"];
   initialPrompt?: string;
   className?: string;
+  onStreamingChange?: (isStreaming: boolean) => void;
 }
 
 export function ChatContainer({
@@ -20,6 +22,7 @@ export function ChatContainer({
   initialMessages,
   initialPrompt,
   className,
+  onStreamingChange,
 }: ChatContainerProps) {
   const {
     messages,
@@ -36,6 +39,10 @@ export function ChatContainer({
     initialPrompt,
   });
 
+  useEffect(() => {
+    onStreamingChange?.(isStreaming);
+  }, [isStreaming, onStreamingChange]);
+
   const breakpoint = useBreakpoint();
   const isMobile =
     breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";

@@ -1,12 +1,7 @@
-import { Text } from "@/components/atoms/Text/Text";
-
 export function ChatLoader() {
   return (
-    <Text
-      variant="small"
-      className="bg-gradient-to-r from-neutral-600 via-neutral-500 to-neutral-600 bg-[length:200%_100%] bg-clip-text text-xs text-transparent [animation:shimmer_2s_ease-in-out_infinite]"
-    >
-      Taking a bit more time...
-    </Text>
+    <div className="flex items-center gap-2">
+      <div className="h-5 w-5 animate-loader rounded-full bg-black" />
+    </div>
   );
 }

@@ -7,7 +7,6 @@ import {
   ArrowsClockwiseIcon,
   CheckCircleIcon,
   CheckIcon,
-  CopyIcon,
 } from "@phosphor-icons/react";
 import { useRouter } from "next/navigation";
 import { useCallback, useState } from "react";
@@ -340,11 +339,26 @@ export function ChatMessage({
             size="icon"
             onClick={handleCopy}
             aria-label="Copy message"
+            className="p-1"
           >
             {copied ? (
               <CheckIcon className="size-4 text-green-600" />
             ) : (
-              <CopyIcon className="size-4 text-zinc-600" />
+              <svg
+                xmlns="http://www.w3.org/2000/svg"
+                width="24"
+                height="24"
+                viewBox="0 0 24 24"
+                fill="none"
+                stroke="currentColor"
+                strokeWidth="2"
+                strokeLinecap="round"
+                strokeLinejoin="round"
+                className="size-3 text-zinc-600"
+              >
+                <rect width="14" height="14" x="8" y="8" rx="2" ry="2" />
+                <path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2" />
+              </svg>
             )}
           </Button>
         )}

@@ -1,7 +1,6 @@
 import { cn } from "@/lib/utils";
 import { useEffect, useRef, useState } from "react";
 import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
-import { ChatLoader } from "../ChatLoader/ChatLoader";
 
 export interface ThinkingMessageProps {
   className?: string;
@@ -9,7 +8,9 @@ export interface ThinkingMessageProps {
 
 export function ThinkingMessage({ className }: ThinkingMessageProps) {
   const [showSlowLoader, setShowSlowLoader] = useState(false);
+  const [showCoffeeMessage, setShowCoffeeMessage] = useState(false);
   const timerRef = useRef<NodeJS.Timeout | null>(null);
+  const coffeeTimerRef = useRef<NodeJS.Timeout | null>(null);
 
   useEffect(() => {
     if (timerRef.current === null) {
@@ -18,11 +19,21 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
       }, 8000);
     }
 
+    if (coffeeTimerRef.current === null) {
+      coffeeTimerRef.current = setTimeout(() => {
+        setShowCoffeeMessage(true);
+      }, 10000);
+    }
+
     return () => {
       if (timerRef.current) {
         clearTimeout(timerRef.current);
         timerRef.current = null;
       }
+      if (coffeeTimerRef.current) {
+        clearTimeout(coffeeTimerRef.current);
+        coffeeTimerRef.current = null;
+      }
     };
   }, []);
 
@@ -37,16 +48,16 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
       <div className="flex min-w-0 flex-1 flex-col">
         <AIChatBubble>
           <div className="transition-all duration-500 ease-in-out">
-            {showSlowLoader ? (
-              <ChatLoader />
+            {showCoffeeMessage ? (
+              <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
+                This could take a few minutes, grab a coffee ☕️
+              </span>
+            ) : showSlowLoader ? (
+              <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
+                Taking a bit more time...
+              </span>
             ) : (
-              <span
-                className="inline-block bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-clip-text text-transparent"
-                style={{
-                  backgroundSize: "200% 100%",
-                  animation: "shimmer 2s ease-in-out infinite",
-                }}
-              >
+              <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
                 Thinking...
               </span>
             )}

@@ -31,6 +31,29 @@ export function FloatingReviewsPanel({
     query: {
       enabled: !!(graphId && executionId),
       select: okData,
+      // Poll while execution is in progress to detect status changes
+      refetchInterval: (q) => {
+        // Note: refetchInterval callback receives raw data before select transform
+        const rawData = q.state.data as
+          | { status: number; data?: { status?: string } }
+          | undefined;
+        if (rawData?.status !== 200) return false;
+
+        const status = rawData?.data?.status;
+        if (!status) return false;
+
+        // Poll every 2 seconds while running or in review
+        if (
+          status === AgentExecutionStatus.RUNNING ||
+          status === AgentExecutionStatus.QUEUED ||
+          status === AgentExecutionStatus.INCOMPLETE ||
+          status === AgentExecutionStatus.REVIEW
+        ) {
+          return 2000;
+        }
+        return false;
+      },
+      refetchIntervalInBackground: true,
     },
   },
 );
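TanStack Query accepts a function for `refetchInterval`: returning a number keeps polling at that cadence, returning `false` stops it. As the comment above notes, the callback sees the raw cached response, not the `select`-transformed value. A sketch of the same technique under that assumption (`fetchExecution` is a hypothetical fetcher, not this codebase's generated client):

```ts
import { useQuery } from "@tanstack/react-query";

declare function fetchExecution(
  id: string,
): Promise<{ status: number; data?: { status?: string } }>;

const ACTIVE_STATES = new Set(["RUNNING", "QUEUED", "INCOMPLETE", "REVIEW"]);

function useExecutionStatus(executionId: string) {
  return useQuery({
    queryKey: ["execution", executionId],
    queryFn: () => fetchExecution(executionId),
    select: (res) => res.data, // consumers get the unwrapped payload
    refetchInterval: (query) => {
      // Raw cache entry, before `select` runs.
      const raw = query.state.data;
      if (raw?.status !== 200) return false;
      // Poll every 2s only while the execution can still change state.
      return ACTIVE_STATES.has(raw.data?.status ?? "") ? 2000 : false;
    },
    refetchIntervalInBackground: true, // keep polling in inactive tabs
  });
}
```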
@@ -40,28 +63,47 @@ export function FloatingReviewsPanel({
     useShallow((state) => state.graphExecutionStatus),
   );

+  // Determine if we should poll for pending reviews
+  const isInReviewStatus =
+    executionDetails?.status === AgentExecutionStatus.REVIEW ||
+    graphExecutionStatus === AgentExecutionStatus.REVIEW;
+
   const { pendingReviews, isLoading, refetch } = usePendingReviewsForExecution(
     executionId || "",
+    {
+      enabled: !!executionId,
+      // Poll every 2 seconds when in REVIEW status to catch new reviews
+      refetchInterval: isInReviewStatus ? 2000 : false,
+    },
   );

+  // Refetch pending reviews when execution status changes
   useEffect(() => {
-    if (executionId) {
+    if (executionId && executionDetails?.status) {
       refetch();
     }
   }, [executionDetails?.status, executionId, refetch]);

-  // Refetch when graph execution status changes to REVIEW
-  useEffect(() => {
-    if (graphExecutionStatus === AgentExecutionStatus.REVIEW && executionId) {
-      refetch();
-    }
-  }, [graphExecutionStatus, executionId, refetch]);
+  // Hide panel if:
+  // 1. No execution ID
+  // 2. No pending reviews and not in REVIEW status
+  // 3. Execution is RUNNING or QUEUED (hasn't paused for review yet)
+  if (!executionId) {
+    return null;
+  }

   if (
-    !executionId ||
-    (!isLoading &&
-      pendingReviews.length === 0 &&
-      executionDetails?.status !== AgentExecutionStatus.REVIEW)
+    !isLoading &&
+    pendingReviews.length === 0 &&
+    executionDetails?.status !== AgentExecutionStatus.REVIEW
+  ) {
+    return null;
+  }
+
+  // Don't show panel while execution is still running/queued (not paused for review)
+  if (
+    executionDetails?.status === AgentExecutionStatus.RUNNING ||
+    executionDetails?.status === AgentExecutionStatus.QUEUED
   ) {
     return null;
   }
@@ -1,10 +1,8 @@
 import { PendingHumanReviewModel } from "@/app/api/__generated__/models/pendingHumanReviewModel";
 import { Text } from "@/components/atoms/Text/Text";
-import { Button } from "@/components/atoms/Button/Button";
 import { Input } from "@/components/atoms/Input/Input";
 import { Switch } from "@/components/atoms/Switch/Switch";
-import { TrashIcon, EyeSlashIcon } from "@phosphor-icons/react";
-import { useState } from "react";
+import { useEffect, useState } from "react";

 interface StructuredReviewPayload {
   data: unknown;
@@ -40,37 +38,49 @@ function extractReviewData(payload: unknown): {
 interface PendingReviewCardProps {
   review: PendingHumanReviewModel;
   onReviewDataChange: (nodeExecId: string, data: string) => void;
-  reviewMessage?: string;
-  onReviewMessageChange?: (nodeExecId: string, message: string) => void;
-  isDisabled?: boolean;
-  onToggleDisabled?: (nodeExecId: string) => void;
+  autoApproveFuture?: boolean;
+  onAutoApproveFutureChange?: (nodeExecId: string, enabled: boolean) => void;
+  externalDataValue?: string;
+  showAutoApprove?: boolean;
+  nodeId?: string;
 }

 export function PendingReviewCard({
   review,
   onReviewDataChange,
-  reviewMessage = "",
-  onReviewMessageChange,
-  isDisabled = false,
-  onToggleDisabled,
+  autoApproveFuture = false,
+  onAutoApproveFutureChange,
+  externalDataValue,
+  showAutoApprove = true,
+  nodeId,
 }: PendingReviewCardProps) {
   const extractedData = extractReviewData(review.payload);
   const isDataEditable = review.editable;
-  const instructions = extractedData.instructions || review.instructions;
+
+  let instructions = review.instructions;
+
+  const isHITLBlock = instructions && !instructions.includes("Block");
+
+  if (instructions && !isHITLBlock) {
+    instructions = undefined;
+  }
+
   const [currentData, setCurrentData] = useState(extractedData.data);

+  useEffect(() => {
+    if (externalDataValue !== undefined) {
+      try {
+        const parsedData = JSON.parse(externalDataValue);
+        setCurrentData(parsedData);
+      } catch {}
+    }
+  }, [externalDataValue]);
+
   const handleDataChange = (newValue: unknown) => {
     setCurrentData(newValue);
     onReviewDataChange(review.node_exec_id, JSON.stringify(newValue, null, 2));
   };

-  const handleMessageChange = (newMessage: string) => {
-    onReviewMessageChange?.(review.node_exec_id, newMessage);
-  };
-
-  // Show simplified view when no toggle functionality is provided (Screenshot 1 mode)
-  const showSimplified = !onToggleDisabled;
-
   const renderDataInput = () => {
     const data = currentData;

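The new `externalDataValue` prop lets the parent push serialized edits back into the card's local state, with a silent bail-out on malformed JSON. The same sync-from-prop pattern in isolation (the hook and parameter names here are illustrative):

```tsx
import { useEffect, useState } from "react";

// Local editable value that adopts an externally owned JSON string
// whenever the parent pushes a new, parseable one.
function useSyncedJsonState(initial: unknown, externalJson?: string) {
  const [value, setValue] = useState(initial);

  useEffect(() => {
    if (externalJson === undefined) return;
    try {
      setValue(JSON.parse(externalJson));
    } catch {
      // Malformed JSON: keep the last good local value.
    }
  }, [externalJson]);

  return [value, setValue] as const;
}
```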
@@ -137,97 +147,59 @@ export function PendingReviewCard({
     }
   };

-  // Helper function to get proper field label
-  const getFieldLabel = (instructions?: string) => {
-    if (instructions)
-      return instructions.charAt(0).toUpperCase() + instructions.slice(1);
-    return "Data to Review";
+  const getShortenedNodeId = (id: string) => {
+    if (id.length <= 8) return id;
+    return `${id.slice(0, 4)}...${id.slice(-4)}`;
   };

-  // Use the existing HITL review interface
   return (
     <div className="space-y-4">
-      {!showSimplified && (
-        <div className="flex items-start justify-between">
-          <div className="flex-1">
-            {isDisabled && (
-              <Text variant="small" className="text-muted-foreground">
-                This item will be rejected
-              </Text>
-            )}
-          </div>
-          <Button
-            onClick={() => onToggleDisabled!(review.node_exec_id)}
-            variant={isDisabled ? "primary" : "secondary"}
-            size="small"
-            leftIcon={
-              isDisabled ? <EyeSlashIcon size={14} /> : <TrashIcon size={14} />
-            }
-          >
-            {isDisabled ? "Include" : "Exclude"}
-          </Button>
-        </div>
-      )}
-
-      {/* Show instructions as field label */}
-      {instructions && (
-        <div className="space-y-3">
-          <Text variant="body" className="font-semibold text-gray-900">
-            {getFieldLabel(instructions)}
-          </Text>
-          {isDataEditable && !isDisabled ? (
-            renderDataInput()
-          ) : (
-            <div className="rounded-lg border border-gray-200 bg-white p-3">
-              <Text variant="small" className="text-gray-600">
-                {JSON.stringify(currentData, null, 2)}
-              </Text>
-            </div>
-          )}
-        </div>
-      )}
-
-      {/* If no instructions, show data directly */}
-      {!instructions && (
-        <div className="space-y-3">
-          <Text variant="body" className="font-semibold text-gray-900">
-            Data to Review
-            {!isDataEditable && (
-              <span className="ml-2 text-xs text-muted-foreground">
-                (Read-only)
-              </span>
-            )}
-          </Text>
-          {isDataEditable && !isDisabled ? (
-            renderDataInput()
-          ) : (
-            <div className="rounded-lg border border-gray-200 bg-white p-3">
-              <Text variant="small" className="text-gray-600">
-                {JSON.stringify(currentData, null, 2)}
-              </Text>
-            </div>
-          )}
-        </div>
-      )}
-
-      {!showSimplified && isDisabled && (
-        <div>
-          <Text variant="body" className="mb-2 font-semibold">
-            Rejection Reason (Optional):
-          </Text>
-          <Input
-            id="rejection-reason"
-            label="Rejection Reason"
-            hideLabel
-            size="small"
-            type="textarea"
-            rows={3}
-            value={reviewMessage}
-            onChange={(e) => handleMessageChange(e.target.value)}
-            placeholder="Add any notes about why you're rejecting this..."
-          />
-        </div>
-      )}
+      {nodeId && (
+        <Text variant="small" className="text-gray-500">
+          Node #{getShortenedNodeId(nodeId)}
+        </Text>
+      )}
+
+      <div className="space-y-3">
+        {instructions && (
+          <Text variant="body" className="font-semibold text-gray-900">
+            {instructions}
+          </Text>
+        )}
+
+        {isDataEditable && !autoApproveFuture ? (
+          renderDataInput()
+        ) : (
+          <div className="rounded-lg border border-gray-200 bg-white p-3">
+            <Text variant="small" className="text-gray-600">
+              {JSON.stringify(currentData, null, 2)}
+            </Text>
+          </div>
+        )}
+      </div>
+
+      {/* Auto-approve toggle for this review */}
+      {showAutoApprove && onAutoApproveFutureChange && (
+        <div className="space-y-2 pt-2">
+          <div className="flex items-center gap-3">
+            <Switch
+              checked={autoApproveFuture}
+              onCheckedChange={(enabled: boolean) =>
+                onAutoApproveFutureChange(review.node_exec_id, enabled)
+              }
+            />
+            <Text variant="small" className="text-gray-700">
+              Auto-approve future executions of this block
+            </Text>
+          </div>
+          {autoApproveFuture && (
+            <Text variant="small" className="pl-11 text-gray-500">
+              Original data will be used for this and all future reviews from
+              this block.
+            </Text>
+          )}
+        </div>
+      )}
     </div>
   );
 }
@@ -1,10 +1,16 @@
-import { useState } from "react";
+import { useMemo, useState } from "react";
 import { PendingHumanReviewModel } from "@/app/api/__generated__/models/pendingHumanReviewModel";
 import { PendingReviewCard } from "@/components/organisms/PendingReviewCard/PendingReviewCard";
 import { Text } from "@/components/atoms/Text/Text";
 import { Button } from "@/components/atoms/Button/Button";
+import { Switch } from "@/components/atoms/Switch/Switch";
 import { useToast } from "@/components/molecules/Toast/use-toast";
-import { ClockIcon, WarningIcon } from "@phosphor-icons/react";
+import {
+  ClockIcon,
+  WarningIcon,
+  CaretDownIcon,
+  CaretRightIcon,
+} from "@phosphor-icons/react";
 import { usePostV2ProcessReviewAction } from "@/app/api/__generated__/endpoints/executions/executions";

 interface PendingReviewsListProps {
@@ -32,16 +38,34 @@ export function PendingReviewsList({
     },
   );

-  const [reviewMessageMap, setReviewMessageMap] = useState<
-    Record<string, string>
-  >({});
-
   const [pendingAction, setPendingAction] = useState<
     "approve" | "reject" | null
   >(null);

+  const [autoApproveFutureMap, setAutoApproveFutureMap] = useState<
+    Record<string, boolean>
+  >({});
+
+  const [collapsedGroups, setCollapsedGroups] = useState<
+    Record<string, boolean>
+  >({});
+
   const { toast } = useToast();

+  const groupedReviews = useMemo(() => {
+    return reviews.reduce(
+      (acc, review) => {
+        const nodeId = review.node_id || "unknown";
+        if (!acc[nodeId]) {
+          acc[nodeId] = [];
+        }
+        acc[nodeId].push(review);
+        return acc;
+      },
+      {} as Record<string, PendingHumanReviewModel[]>,
+    );
+  }, [reviews]);
+
   const reviewActionMutation = usePostV2ProcessReviewAction({
     mutation: {
       onSuccess: (res) => {
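`groupedReviews` buckets the flat review list by `node_id` once per `reviews` change. The grouping step on its own, with a minimal `Review` shape standing in for `PendingHumanReviewModel`:

```ts
interface Review {
  node_id?: string;
  node_exec_id: string;
}

function groupByNode(reviews: Review[]): Record<string, Review[]> {
  return reviews.reduce(
    (acc, review) => {
      const nodeId = review.node_id ?? "unknown";
      // Create the bucket on first sight, then append.
      (acc[nodeId] ??= []).push(review);
      return acc;
    },
    {} as Record<string, Review[]>,
  );
}
```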
@@ -88,8 +112,33 @@ export function PendingReviewsList({
     setReviewDataMap((prev) => ({ ...prev, [nodeExecId]: data }));
   }

-  function handleReviewMessageChange(nodeExecId: string, message: string) {
-    setReviewMessageMap((prev) => ({ ...prev, [nodeExecId]: message }));
+  function handleAutoApproveFutureToggle(nodeId: string, enabled: boolean) {
+    setAutoApproveFutureMap((prev) => ({
+      ...prev,
+      [nodeId]: enabled,
+    }));
+
+    if (enabled) {
+      const nodeReviews = groupedReviews[nodeId] || [];
+      setReviewDataMap((prev) => {
+        const updated = { ...prev };
+        nodeReviews.forEach((review) => {
+          updated[review.node_exec_id] = JSON.stringify(
+            review.payload,
+            null,
+            2,
+          );
+        });
+        return updated;
+      });
+    }
+  }
+
+  function toggleGroupCollapse(nodeId: string) {
+    setCollapsedGroups((prev) => ({
+      ...prev,
+      [nodeId]: !prev[nodeId],
+    }));
   }

   function processReviews(approved: boolean) {
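When auto-approve is switched on, the handler resets every draft for that node back to the original payload in a single functional set-state pass, so no intermediate render sees a half-updated map. The same bulk-reset shape in isolation (hook and parameter names are illustrative):

```ts
import { useState } from "react";

function useDraftMap() {
  const [drafts, setDrafts] = useState<Record<string, string>>({});

  // Overwrite a set of keys in one pass; the updater receives the
  // latest state even if several resets are queued in the same tick.
  function resetDrafts(ids: string[], toValue: (id: string) => string) {
    setDrafts((prev) => {
      const updated = { ...prev };
      ids.forEach((id) => {
        updated[id] = toValue(id);
      });
      return updated;
    });
  }

  return { drafts, resetDrafts };
}
```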
@@ -107,22 +156,25 @@ export function PendingReviewsList({

   for (const review of reviews) {
     const reviewData = reviewDataMap[review.node_exec_id];
-    const reviewMessage = reviewMessageMap[review.node_exec_id];
+    const autoApproveThisNode = autoApproveFutureMap[review.node_id || ""];

-    let parsedData: any = review.payload; // Default to original payload
+    let parsedData: any = undefined;

-    // Parse edited data if available and editable
+    if (!autoApproveThisNode) {
       if (review.editable && reviewData) {
         try {
           parsedData = JSON.parse(reviewData);
         } catch (error) {
           toast({
             title: "Invalid JSON",
             description: `Please fix the JSON format in review for node ${review.node_exec_id}: ${error instanceof Error ? error.message : "Invalid syntax"}`,
             variant: "destructive",
           });
           setPendingAction(null);
           return;
+        }
+      } else {
+        parsedData = review.payload;
       }
     }

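The submit loop still validates hand-edited JSON per review and aborts the whole batch on the first parse failure. Factored out, the parse-or-bail step looks roughly like this (a sketch, not code from the PR):

```ts
type ParseResult =
  | { ok: true; value: unknown }
  | { ok: false; error: string };

// Returns a tagged result instead of throwing, so the caller can
// surface a toast with the message and abort the batch.
function tryParseJson(text: string): ParseResult {
  try {
    return { ok: true, value: JSON.parse(text) };
  } catch (error) {
    return {
      ok: false,
      error: error instanceof Error ? error.message : "Invalid syntax",
    };
  }
}
```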
@@ -130,7 +182,7 @@ export function PendingReviewsList({
       node_exec_id: review.node_exec_id,
       approved,
       reviewed_data: parsedData,
-      message: reviewMessage || undefined,
+      auto_approve_future: autoApproveThisNode && approved,
     });
   }

@@ -158,7 +210,6 @@ export function PendingReviewsList({

   return (
     <div className="space-y-7 rounded-xl border border-yellow-150 bg-yellow-25 p-6">
-      {/* Warning Box Header */}
       <div className="space-y-6">
         <div className="flex items-start gap-2">
           <WarningIcon
@@ -180,23 +231,76 @@ export function PendingReviewsList({
         </div>

         <div className="space-y-7">
-          {reviews.map((review) => (
-            <PendingReviewCard
-              key={review.node_exec_id}
-              review={review}
-              onReviewDataChange={handleReviewDataChange}
-              onReviewMessageChange={handleReviewMessageChange}
-              reviewMessage={reviewMessageMap[review.node_exec_id] || ""}
-            />
-          ))}
+          {Object.entries(groupedReviews).map(([nodeId, nodeReviews]) => {
+            const isCollapsed = collapsedGroups[nodeId] ?? nodeReviews.length > 1;
+            const reviewCount = nodeReviews.length;
+
+            const firstReview = nodeReviews[0];
+            const blockName = firstReview?.instructions;
+            const reviewTitle = `Review required for ${blockName}`;
+
+            const getShortenedNodeId = (id: string) => {
+              if (id.length <= 8) return id;
+              return `${id.slice(0, 4)}...${id.slice(-4)}`;
+            };
+
+            return (
+              <div key={nodeId} className="space-y-4">
+                <button
+                  onClick={() => toggleGroupCollapse(nodeId)}
+                  className="flex w-full items-center gap-2 text-left"
+                >
+                  {isCollapsed ? (
+                    <CaretRightIcon size={20} className="text-gray-600" />
+                  ) : (
+                    <CaretDownIcon size={20} className="text-gray-600" />
+                  )}
+                  <div className="flex-1">
+                    <Text variant="body" className="font-semibold text-gray-900">
+                      {reviewTitle}
+                    </Text>
+                    <Text variant="small" className="text-gray-500">
+                      Node #{getShortenedNodeId(nodeId)}
+                    </Text>
+                  </div>
+                  <span className="text-xs text-gray-600">
+                    {reviewCount} {reviewCount === 1 ? "review" : "reviews"}
+                  </span>
+                </button>
+
+                {!isCollapsed && (
+                  <div className="space-y-4">
+                    {nodeReviews.map((review) => (
+                      <PendingReviewCard
+                        key={review.node_exec_id}
+                        review={review}
+                        onReviewDataChange={handleReviewDataChange}
+                        autoApproveFuture={autoApproveFutureMap[nodeId] || false}
+                        externalDataValue={reviewDataMap[review.node_exec_id]}
+                        showAutoApprove={false}
+                      />
+                    ))}
+
+                    <div className="flex items-center gap-3 pt-2">
+                      <Switch
+                        checked={autoApproveFutureMap[nodeId] || false}
+                        onCheckedChange={(enabled: boolean) =>
+                          handleAutoApproveFutureToggle(nodeId, enabled)
+                        }
+                      />
+                      <Text variant="small" className="text-gray-700">
+                        Auto-approve future executions of this node
+                      </Text>
+                    </div>
+                  </div>
+                )}
+              </div>
+            );
+          })}
         </div>

-        <div className="space-y-7">
-          <Text variant="body" className="text-textGrey">
-            Note: Changes you make here apply only to this task
-          </Text>
-
-          <div className="flex gap-2">
+        <div className="space-y-4">
+          <div className="flex flex-wrap gap-2">
             <Button
               onClick={() => processReviews(true)}
               disabled={reviewActionMutation.isPending || reviews.length === 0}
@@ -220,6 +324,11 @@ export function PendingReviewsList({
               Reject
             </Button>
           </div>
+
+          <Text variant="small" className="text-textGrey">
+            You can turn auto-approval on or off using the toggle above for each
+            node.
+          </Text>
         </div>
       </div>
     );
@@ -15,8 +15,22 @@ export function usePendingReviews() {
   };
 }

-export function usePendingReviewsForExecution(graphExecId: string) {
-  const query = useGetV2GetPendingReviewsForExecution(graphExecId);
+interface UsePendingReviewsForExecutionOptions {
+  enabled?: boolean;
+  refetchInterval?: number | false;
+}
+
+export function usePendingReviewsForExecution(
+  graphExecId: string,
+  options?: UsePendingReviewsForExecutionOptions,
+) {
+  const query = useGetV2GetPendingReviewsForExecution(graphExecId, {
+    query: {
+      enabled: options?.enabled ?? !!graphExecId,
+      refetchInterval: options?.refetchInterval,
+      refetchIntervalInBackground: !!options?.refetchInterval,
+    },
+  });

   return {
     pendingReviews: okData(query.data) || [],
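The hook now threads caller-supplied options into the generated query hook instead of hard-coding them, defaulting `enabled` to the presence of an ID and turning on background refetch only when polling is requested. A sketch of the same wrapper shape against plain TanStack Query (`fetchPendingReviews` is a hypothetical fetcher):

```ts
import { useQuery } from "@tanstack/react-query";

declare function fetchPendingReviews(execId: string): Promise<unknown[]>;

interface PollingOptions {
  enabled?: boolean;
  refetchInterval?: number | false;
}

function usePendingReviewsQuery(execId: string, options?: PollingOptions) {
  return useQuery({
    queryKey: ["pending-reviews", execId],
    queryFn: () => fetchPendingReviews(execId),
    // Callers can force-disable; otherwise gate on having an ID at all.
    enabled: options?.enabled ?? !!execId,
    refetchInterval: options?.refetchInterval,
    // Only poll hidden tabs when polling was requested in the first place.
    refetchIntervalInBackground: !!options?.refetchInterval,
  });
}
```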
@@ -10,6 +10,7 @@ export enum Key {
   LIBRARY_AGENTS_CACHE = "library-agents-cache",
   CHAT_SESSION_ID = "chat_session_id",
   COOKIE_CONSENT = "autogpt_cookie_consent",
+  AI_AGENT_SAFETY_POPUP_SHOWN = "ai-agent-safety-popup-shown",
 }

 function get(key: Key) {
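The new key presumably gates a one-time safety popup. Typical usage against this module's helpers would look roughly like the sketch below (`storage` and `showSafetyPopup` are stand-ins, not code from this PR):

```ts
declare const storage: {
  get(key: string): string | null;
  set(key: string, value: string): void;
};
declare function showSafetyPopup(): void;

const AI_AGENT_SAFETY_POPUP_SHOWN = "ai-agent-safety-popup-shown";

// Show the popup once, then persist the flag so it never re-appears.
if (!storage.get(AI_AGENT_SAFETY_POPUP_SHOWN)) {
  showSafetyPopup();
  storage.set(AI_AGENT_SAFETY_POPUP_SHOWN, "true");
}
```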
@@ -157,12 +157,21 @@ const config = {
         backgroundPosition: "-200% 0",
       },
     },
+    loader: {
+      "0%": {
+        boxShadow: "0 0 0 0 rgba(0, 0, 0, 0.25)",
+      },
+      "100%": {
+        boxShadow: "0 0 0 30px rgba(0, 0, 0, 0)",
+      },
+    },
   },
   animation: {
     "accordion-down": "accordion-down 0.2s ease-out",
     "accordion-up": "accordion-up 0.2s ease-out",
     "fade-in": "fade-in 0.2s ease-out",
     shimmer: "shimmer 2s ease-in-out infinite",
+    loader: "loader 1s infinite",
   },
   transitionDuration: {
     "2000": "2000ms",
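Registering the `loader` keyframes plus the matching `animation` entry makes Tailwind generate an `animate-loader` utility: a one-second expanding box-shadow pulse. A sketch of an element opting in:

```tsx
// Any element can now use the generated utility class.
export function PulseDot() {
  return (
    <span className="inline-block h-3 w-3 rounded-full bg-black animate-loader" />
  );
}
```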
BIN docs/integrations/.gitbook/assets/Ollama-Add-Prompts.png (new file, 115 KiB)
BIN docs/integrations/.gitbook/assets/Ollama-Output.png (new file, 29 KiB)
BIN docs/integrations/.gitbook/assets/Ollama-Remote-Host.png (new file, 6.0 KiB)
BIN docs/integrations/.gitbook/assets/Ollama-Select-Llama32.png (new file, 81 KiB)
BIN docs/integrations/.gitbook/assets/Select-AI-block.png (new file, 116 KiB)
BIN docs/integrations/.gitbook/assets/e2b-dashboard.png (new file, 504 KiB)
BIN docs/integrations/.gitbook/assets/e2b-log-url.png (new file, 43 KiB)
BIN docs/integrations/.gitbook/assets/e2b-new-tag.png (new file, 47 KiB)
BIN docs/integrations/.gitbook/assets/e2b-tag-button.png (new file, 20 KiB)
BIN docs/integrations/.gitbook/assets/get-repo-dialog.png (new file, 68 KiB)
docs/integrations/SUMMARY.md (new file, 133 lines)
@@ -0,0 +1,133 @@
+# Table of contents
+
+* [AutoGPT Blocks Overview](README.md)
+
+## Guides
+
+* [LLM Providers](guides/llm-providers.md)
+* [Voice Providers](guides/voice-providers.md)
+
+## Block Integrations
+
+* [Airtable Bases](block-integrations/airtable/bases.md)
+* [Airtable Records](block-integrations/airtable/records.md)
+* [Airtable Schema](block-integrations/airtable/schema.md)
+* [Airtable Triggers](block-integrations/airtable/triggers.md)
+* [Apollo Organization](block-integrations/apollo/organization.md)
+* [Apollo People](block-integrations/apollo/people.md)
+* [Apollo Person](block-integrations/apollo/person.md)
+* [Ayrshare Post To Bluesky](block-integrations/ayrshare/post_to_bluesky.md)
+* [Ayrshare Post To Facebook](block-integrations/ayrshare/post_to_facebook.md)
+* [Ayrshare Post To GMB](block-integrations/ayrshare/post_to_gmb.md)
+* [Ayrshare Post To Instagram](block-integrations/ayrshare/post_to_instagram.md)
+* [Ayrshare Post To LinkedIn](block-integrations/ayrshare/post_to_linkedin.md)
+* [Ayrshare Post To Pinterest](block-integrations/ayrshare/post_to_pinterest.md)
+* [Ayrshare Post To Reddit](block-integrations/ayrshare/post_to_reddit.md)
+* [Ayrshare Post To Snapchat](block-integrations/ayrshare/post_to_snapchat.md)
+* [Ayrshare Post To Telegram](block-integrations/ayrshare/post_to_telegram.md)
+* [Ayrshare Post To Threads](block-integrations/ayrshare/post_to_threads.md)
+* [Ayrshare Post To TikTok](block-integrations/ayrshare/post_to_tiktok.md)
+* [Ayrshare Post To X](block-integrations/ayrshare/post_to_x.md)
+* [Ayrshare Post To YouTube](block-integrations/ayrshare/post_to_youtube.md)
+* [Baas Bots](block-integrations/baas/bots.md)
+* [Bannerbear Text Overlay](block-integrations/bannerbear/text_overlay.md)
+* [Basic](block-integrations/basic.md)
+* [Compass Triggers](block-integrations/compass/triggers.md)
+* [Data](block-integrations/data.md)
+* [Dataforseo Keyword Suggestions](block-integrations/dataforseo/keyword_suggestions.md)
+* [Dataforseo Related Keywords](block-integrations/dataforseo/related_keywords.md)
+* [Discord Bot Blocks](block-integrations/discord/bot_blocks.md)
+* [Discord OAuth Blocks](block-integrations/discord/oauth_blocks.md)
+* [Enrichlayer LinkedIn](block-integrations/enrichlayer/linkedin.md)
+* [Exa Answers](block-integrations/exa/answers.md)
+* [Exa Code Context](block-integrations/exa/code_context.md)
+* [Exa Contents](block-integrations/exa/contents.md)
+* [Exa Research](block-integrations/exa/research.md)
+* [Exa Search](block-integrations/exa/search.md)
+* [Exa Similar](block-integrations/exa/similar.md)
+* [Exa Webhook Blocks](block-integrations/exa/webhook_blocks.md)
+* [Exa Websets](block-integrations/exa/websets.md)
+* [Exa Websets Enrichment](block-integrations/exa/websets_enrichment.md)
+* [Exa Websets Import Export](block-integrations/exa/websets_import_export.md)
+* [Exa Websets Items](block-integrations/exa/websets_items.md)
+* [Exa Websets Monitor](block-integrations/exa/websets_monitor.md)
+* [Exa Websets Polling](block-integrations/exa/websets_polling.md)
+* [Exa Websets Search](block-integrations/exa/websets_search.md)
+* [Fal AI Video Generator](block-integrations/fal/ai_video_generator.md)
+* [Firecrawl Crawl](block-integrations/firecrawl/crawl.md)
+* [Firecrawl Extract](block-integrations/firecrawl/extract.md)
+* [Firecrawl Map](block-integrations/firecrawl/map.md)
+* [Firecrawl Scrape](block-integrations/firecrawl/scrape.md)
+* [Firecrawl Search](block-integrations/firecrawl/search.md)
+* [Generic Webhook Triggers](block-integrations/generic_webhook/triggers.md)
+* [GitHub Checks](block-integrations/github/checks.md)
+* [GitHub CI](block-integrations/github/ci.md)
+* [GitHub Issues](block-integrations/github/issues.md)
+* [GitHub Pull Requests](block-integrations/github/pull_requests.md)
+* [GitHub Repo](block-integrations/github/repo.md)
+* [GitHub Reviews](block-integrations/github/reviews.md)
+* [GitHub Statuses](block-integrations/github/statuses.md)
+* [GitHub Triggers](block-integrations/github/triggers.md)
+* [Google Calendar](block-integrations/google/calendar.md)
+* [Google Docs](block-integrations/google/docs.md)
+* [Google Gmail](block-integrations/google/gmail.md)
+* [Google Sheets](block-integrations/google/sheets.md)
+* [HubSpot Company](block-integrations/hubspot/company.md)
+* [HubSpot Contact](block-integrations/hubspot/contact.md)
+* [HubSpot Engagement](block-integrations/hubspot/engagement.md)
+* [Jina Chunking](block-integrations/jina/chunking.md)
+* [Jina Embeddings](block-integrations/jina/embeddings.md)
+* [Jina Fact Checker](block-integrations/jina/fact_checker.md)
+* [Jina Search](block-integrations/jina/search.md)
+* [Linear Comment](block-integrations/linear/comment.md)
+* [Linear Issues](block-integrations/linear/issues.md)
+* [Linear Projects](block-integrations/linear/projects.md)
+* [LLM](block-integrations/llm.md)
+* [Logic](block-integrations/logic.md)
+* [Misc](block-integrations/misc.md)
+* [Multimedia](block-integrations/multimedia.md)
+* [Notion Create Page](block-integrations/notion/create_page.md)
+* [Notion Read Database](block-integrations/notion/read_database.md)
+* [Notion Read Page](block-integrations/notion/read_page.md)
+* [Notion Read Page Markdown](block-integrations/notion/read_page_markdown.md)
+* [Notion Search](block-integrations/notion/search.md)
+* [Nvidia Deepfake](block-integrations/nvidia/deepfake.md)
+* [Replicate Flux Advanced](block-integrations/replicate/flux_advanced.md)
+* [Replicate Replicate Block](block-integrations/replicate/replicate_block.md)
+* [Search](block-integrations/search.md)
+* [Slant3D Filament](block-integrations/slant3d/filament.md)
+* [Slant3D Order](block-integrations/slant3d/order.md)
+* [Slant3D Slicing](block-integrations/slant3d/slicing.md)
+* [Slant3D Webhook](block-integrations/slant3d/webhook.md)
+* [Smartlead Campaign](block-integrations/smartlead/campaign.md)
+* [Stagehand Blocks](block-integrations/stagehand/blocks.md)
+* [System Library Operations](block-integrations/system/library_operations.md)
+* [System Store Operations](block-integrations/system/store_operations.md)
+* [Text](block-integrations/text.md)
+* [Todoist Comments](block-integrations/todoist/comments.md)
+* [Todoist Labels](block-integrations/todoist/labels.md)
+* [Todoist Projects](block-integrations/todoist/projects.md)
+* [Todoist Sections](block-integrations/todoist/sections.md)
+* [Todoist Tasks](block-integrations/todoist/tasks.md)
+* [Twitter Blocks](block-integrations/twitter/blocks.md)
+* [Twitter Bookmark](block-integrations/twitter/bookmark.md)
+* [Twitter Follows](block-integrations/twitter/follows.md)
+* [Twitter Hide](block-integrations/twitter/hide.md)
+* [Twitter Like](block-integrations/twitter/like.md)
+* [Twitter List Follows](block-integrations/twitter/list_follows.md)
+* [Twitter List Lookup](block-integrations/twitter/list_lookup.md)
+* [Twitter List Members](block-integrations/twitter/list_members.md)
+* [Twitter List Tweets Lookup](block-integrations/twitter/list_tweets_lookup.md)
+* [Twitter Manage](block-integrations/twitter/manage.md)
+* [Twitter Manage Lists](block-integrations/twitter/manage_lists.md)
+* [Twitter Mutes](block-integrations/twitter/mutes.md)
+* [Twitter Pinned Lists](block-integrations/twitter/pinned_lists.md)
+* [Twitter Quote](block-integrations/twitter/quote.md)
+* [Twitter Retweet](block-integrations/twitter/retweet.md)
+* [Twitter Search Spaces](block-integrations/twitter/search_spaces.md)
+* [Twitter Spaces Lookup](block-integrations/twitter/spaces_lookup.md)
+* [Twitter Timeline](block-integrations/twitter/timeline.md)
+* [Twitter Tweet Lookup](block-integrations/twitter/tweet_lookup.md)
+* [Twitter User Lookup](block-integrations/twitter/user_lookup.md)
+* [Wolfram LLM API](block-integrations/wolfram/llm_api.md)
+* [Zerobounce Validate Emails](block-integrations/zerobounce/validate_emails.md)