mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
5 Commits
fix/transc
...
feat/async
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec7c7ebea2 | ||
|
|
8ef8bec14f | ||
|
|
9b3e25d98e | ||
|
|
0bc098acb1 | ||
|
|
d78e0ee122 |
@@ -1,11 +1,13 @@
|
||||
"""External Agent Generator service client.
|
||||
|
||||
This module provides a client for communicating with the external Agent Generator
|
||||
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
||||
will delegate to the external service instead of using the built-in LLM-based implementation.
|
||||
microservice. All generation endpoints use async polling: submit a job (202),
|
||||
then poll GET /api/jobs/{job_id} every few seconds until the result is ready.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -25,22 +27,21 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_dummy_mode_warned = False
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
POLL_INTERVAL_SECONDS = 10.0
|
||||
MAX_POLL_TIME_SECONDS = 1800.0 # 30 minutes
|
||||
MAX_CONSECUTIVE_POLL_ERRORS = 5
|
||||
|
||||
|
||||
def _create_error_response(
|
||||
error_message: str,
|
||||
error_type: str = "unknown",
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a standardized error response dict.
|
||||
|
||||
Args:
|
||||
error_message: Human-readable error message
|
||||
error_type: Machine-readable error type
|
||||
details: Optional additional error details
|
||||
|
||||
Returns:
|
||||
Error dict with type="error" and error details
|
||||
"""
|
||||
"""Create a standardized error response dict."""
|
||||
response: dict[str, Any] = {
|
||||
"type": "error",
|
||||
"error": error_message,
|
||||
@@ -52,14 +53,7 @@ def _create_error_response(
|
||||
|
||||
|
||||
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||
"""Classify an HTTP error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The HTTP status error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
"""Classify an HTTP error into error_type and message."""
|
||||
status = e.response.status_code
|
||||
if status == 429:
|
||||
return "rate_limit", f"Agent Generator rate limited: {e}"
|
||||
@@ -72,14 +66,7 @@ def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||
|
||||
|
||||
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||
"""Classify a request error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The request error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
"""Classify a request error into error_type and message."""
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str or "timed out" in error_str:
|
||||
return "timeout", f"Agent Generator request timed out: {e}"
|
||||
@@ -89,6 +76,10 @@ def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||
return "request_error", f"Request error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client / settings singletons
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_client: httpx.AsyncClient | None = None
|
||||
_settings: Settings | None = None
|
||||
|
||||
@@ -136,13 +127,149 @@ def _get_client() -> httpx.AsyncClient:
|
||||
global _client
|
||||
if _client is None:
|
||||
settings = _get_settings()
|
||||
timeout = httpx.Timeout(float(settings.config.agentgenerator_timeout))
|
||||
_client = httpx.AsyncClient(
|
||||
base_url=_get_base_url(),
|
||||
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
||||
timeout=timeout,
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core polling helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _submit_and_poll(
|
||||
endpoint: str,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Submit a job to the agent-generator and poll until the result is ready.
|
||||
|
||||
The endpoint is expected to return 202 with ``{"job_id": "..."}`` on success.
|
||||
We then poll ``GET /api/jobs/{job_id}`` every ``POLL_INTERVAL_SECONDS``
|
||||
until the job completes or fails.
|
||||
|
||||
Returns:
|
||||
The *result* dict from a completed job, or an error dict.
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# 1. Submit ----------------------------------------------------------------
|
||||
try:
|
||||
response = await client.post(endpoint, json=payload)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
data = response.json()
|
||||
job_id = data.get("job_id")
|
||||
if not job_id:
|
||||
return _create_error_response(
|
||||
"Agent Generator did not return a job_id", "invalid_response"
|
||||
)
|
||||
|
||||
logger.info(f"Agent Generator job submitted: {job_id} via {endpoint}")
|
||||
|
||||
# 2. Poll ------------------------------------------------------------------
|
||||
start = time.monotonic()
|
||||
consecutive_errors = 0
|
||||
while (time.monotonic() - start) < MAX_POLL_TIME_SECONDS:
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
try:
|
||||
poll_resp = await client.get(f"/api/jobs/{job_id}")
|
||||
poll_resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 404:
|
||||
return _create_error_response(
|
||||
"Agent Generator job not found or expired", "job_not_found"
|
||||
)
|
||||
status_code = e.response.status_code
|
||||
if status_code in {429, 503, 504, 408}:
|
||||
consecutive_errors += 1
|
||||
logger.warning(
|
||||
f"Transient HTTP {status_code} polling job {job_id} "
|
||||
f"({consecutive_errors}/{MAX_CONSECUTIVE_POLL_ERRORS}): {e}"
|
||||
)
|
||||
if consecutive_errors >= MAX_CONSECUTIVE_POLL_ERRORS:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(
|
||||
f"Giving up on job {job_id} after "
|
||||
f"{MAX_CONSECUTIVE_POLL_ERRORS} consecutive poll errors: {error_msg}"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
continue
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(f"Poll error for job {job_id}: {error_msg}")
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
consecutive_errors += 1
|
||||
logger.warning(
|
||||
f"Transient poll error for job {job_id} "
|
||||
f"({consecutive_errors}/{MAX_CONSECUTIVE_POLL_ERRORS}): {e}"
|
||||
)
|
||||
if consecutive_errors >= MAX_CONSECUTIVE_POLL_ERRORS:
|
||||
error_msg = (
|
||||
f"Giving up on job {job_id} after "
|
||||
f"{MAX_CONSECUTIVE_POLL_ERRORS} consecutive poll errors: {e}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "poll_error")
|
||||
continue
|
||||
|
||||
consecutive_errors = 0
|
||||
poll_data = poll_resp.json()
|
||||
status = poll_data.get("status")
|
||||
|
||||
if status == "completed":
|
||||
logger.info(f"Agent Generator job {job_id} completed")
|
||||
result = poll_data.get("result", {})
|
||||
if not isinstance(result, dict):
|
||||
return _create_error_response(
|
||||
"Agent Generator returned invalid result payload",
|
||||
"invalid_response",
|
||||
)
|
||||
return result
|
||||
elif status == "failed":
|
||||
error_msg = poll_data.get("error", "Job failed")
|
||||
logger.error(f"Agent Generator job {job_id} failed: {error_msg}")
|
||||
return _create_error_response(error_msg, "job_failed")
|
||||
elif status in {"running", "pending", "queued"}:
|
||||
continue
|
||||
else:
|
||||
return _create_error_response(
|
||||
f"Agent Generator returned unexpected job status: {status}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
return _create_error_response("Agent generation timed out after polling", "timeout")
|
||||
|
||||
|
||||
def _extract_agent_json(result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract and validate agent_json from a job result.
|
||||
|
||||
Returns the agent_json dict, or an error response if missing/invalid.
|
||||
"""
|
||||
agent_json = result.get("agent_json")
|
||||
if not isinstance(agent_json, dict):
|
||||
return _create_error_response(
|
||||
"Agent Generator returned no agent_json in result", "invalid_response"
|
||||
)
|
||||
return agent_json
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public functions — same signatures as before, now using polling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def decompose_goal_external(
|
||||
description: str,
|
||||
context: str = "",
|
||||
@@ -150,25 +277,17 @@ async def decompose_goal_external(
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to decompose a goal.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
Returns one of the following dicts (keyed by ``"type"``):
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
- {"type": "unachievable_goal", ...}
|
||||
- {"type": "vague_goal", ...}
|
||||
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||
Or None on unexpected error
|
||||
* ``{"type": "instructions", "steps": [...]}``
|
||||
* ``{"type": "clarifying_questions", "questions": [...]}``
|
||||
* ``{"type": "unachievable_goal", "reason": ..., "suggested_goal": ...}``
|
||||
* ``{"type": "vague_goal", "suggested_goal": ...}``
|
||||
* ``{"type": "error", "error": ..., "error_type": ...}``
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await decompose_goal_dummy(description, context, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
if context:
|
||||
description = f"{description}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
@@ -177,67 +296,43 @@ async def decompose_goal_external(
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/decompose-description", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator decomposition failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Map the response to the expected format
|
||||
response_type = data.get("type")
|
||||
if response_type == "instructions":
|
||||
return {"type": "instructions", "steps": data.get("steps", [])}
|
||||
elif response_type == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
elif response_type == "unachievable_goal":
|
||||
return {
|
||||
"type": "unachievable_goal",
|
||||
"reason": data.get("reason"),
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "vague_goal":
|
||||
return {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "error":
|
||||
# Pass through error from the service
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown response type from external service: {response_type}"
|
||||
)
|
||||
return _create_error_response(
|
||||
f"Unknown response type from Agent Generator: {response_type}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
result = await _submit_and_poll("/api/decompose-description", payload)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
# The result dict from the job is already in the expected format
|
||||
# (type, steps, questions, etc.) — just return it as-is.
|
||||
if result.get("type") == "error":
|
||||
return result
|
||||
|
||||
response_type = result.get("type")
|
||||
if response_type == "instructions":
|
||||
return {"type": "instructions", "steps": result.get("steps", [])}
|
||||
elif response_type == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": result.get("questions", []),
|
||||
}
|
||||
elif response_type == "unachievable_goal":
|
||||
return {
|
||||
"type": "unachievable_goal",
|
||||
"reason": result.get("reason"),
|
||||
"suggested_goal": result.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "vague_goal":
|
||||
return {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": result.get("suggested_goal"),
|
||||
}
|
||||
else:
|
||||
logger.error(f"Unknown response type from Agent Generator job: {response_type}")
|
||||
return _create_error_response(
|
||||
f"Unknown response type: {response_type}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
@@ -245,51 +340,28 @@ async def generate_agent_external(
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or error dict {"type": "error", ...} on error
|
||||
Agent JSON dict or error dict {"type": "error", ...} on error.
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_dummy(instructions, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
result = await _submit_and_poll("/api/generate-agent", payload)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
if result.get("type") == "error":
|
||||
return result
|
||||
|
||||
return _extract_agent_json(result)
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
@@ -298,24 +370,14 @@ async def generate_agent_patch_external(
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
session_id: Session ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
Updated agent JSON, clarifying questions dict, or error dict.
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_patch_dummy(
|
||||
update_request, current_agent, library_agents
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
@@ -324,49 +386,23 @@ async def generate_agent_patch_external(
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator patch generation failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the updated agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
result = await _submit_and_poll("/api/update-agent", payload)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
if result.get("type") == "error":
|
||||
return result
|
||||
|
||||
if result.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": result.get("questions", []),
|
||||
}
|
||||
|
||||
return _extract_agent_json(result)
|
||||
|
||||
|
||||
async def customize_template_external(
|
||||
template_agent: dict[str, Any],
|
||||
@@ -375,83 +411,51 @@ async def customize_template_external(
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to customize a template/marketplace agent.
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
session_id: Session ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||
Customized agent JSON, clarifying questions dict, or error dict.
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await customize_template_dummy(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
request = modification_request
|
||||
request_text = modification_request
|
||||
if context:
|
||||
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||
request_text = (
|
||||
f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"template_agent_json": template_agent,
|
||||
"modification_request": request,
|
||||
"modification_request": request_text,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.post("/api/template-modification", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator template customization failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the customized agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
result = await _submit_and_poll("/api/template-modification", payload)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
if result.get("type") == "error":
|
||||
return result
|
||||
|
||||
if result.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": result.get("questions", []),
|
||||
}
|
||||
|
||||
return _extract_agent_json(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-generation endpoints (still synchronous — quick responses)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
"""Get available blocks from the external service.
|
||||
|
||||
Returns:
|
||||
List of block info dicts or None on error
|
||||
"""
|
||||
"""Get available blocks from the external service."""
|
||||
if _is_dummy_mode():
|
||||
return await get_blocks_dummy()
|
||||
|
||||
@@ -480,11 +484,7 @@ async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
|
||||
|
||||
async def health_check() -> bool:
|
||||
"""Check if the external service is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
"""Check if the external service is healthy."""
|
||||
if not is_external_service_configured():
|
||||
return False
|
||||
|
||||
|
||||
@@ -372,8 +372,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The port for the Agent Generator service",
|
||||
)
|
||||
agentgenerator_timeout: int = Field(
|
||||
default=1800,
|
||||
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
|
||||
default=30,
|
||||
description="The timeout in seconds for individual Agent Generator HTTP requests (submit and poll)",
|
||||
)
|
||||
agentgenerator_use_dummy: bool = Field(
|
||||
default=False,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Tests for the Agent Generator external service client.
|
||||
|
||||
This test suite verifies the external Agent Generator service integration,
|
||||
including service detection, API calls, and error handling.
|
||||
including service detection, async polling, and error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -49,6 +49,292 @@ class TestServiceConfiguration:
|
||||
assert url == "http://agent-generator.local:8000"
|
||||
|
||||
|
||||
class TestSubmitAndPoll:
|
||||
"""Test the _submit_and_poll helper that handles async job polling."""
|
||||
|
||||
def setup_method(self):
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_submit_and_poll(self):
|
||||
"""Test normal submit -> poll -> completed flow."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-123", "status": "accepted"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = {
|
||||
"job_id": "job-123",
|
||||
"status": "completed",
|
||||
"result": {"type": "instructions", "steps": ["Step 1"]},
|
||||
}
|
||||
poll_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.return_value = poll_resp
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {"key": "value"})
|
||||
|
||||
assert result == {"type": "instructions", "steps": ["Step 1"]}
|
||||
mock_client.post.assert_called_once_with("/api/test", json={"key": "value"})
|
||||
mock_client.get.assert_called_once_with("/api/jobs/job-123")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_returns_failed_job(self):
|
||||
"""Test submit -> poll -> failed flow."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-456", "status": "accepted"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = {
|
||||
"job_id": "job-456",
|
||||
"status": "failed",
|
||||
"error": "Generation failed",
|
||||
}
|
||||
poll_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.return_value = poll_resp
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "job_failed"
|
||||
assert "Generation failed" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_http_error(self):
|
||||
"""Test HTTP error during job submission."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.HTTPStatusError(
|
||||
"Server error", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "http_error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_connection_error(self):
|
||||
"""Test connection error during job submission."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "connection_error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_job_id_in_submit_response(self):
|
||||
"""Test submit response missing job_id."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"status": "accepted"} # no job_id
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "invalid_response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_retries_on_transient_network_error(self):
|
||||
"""Test that transient network errors during polling are retried."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-789"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
ok_poll_resp = MagicMock()
|
||||
ok_poll_resp.json.return_value = {
|
||||
"job_id": "job-789",
|
||||
"status": "completed",
|
||||
"result": {"data": "ok"},
|
||||
}
|
||||
ok_poll_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
# First poll fails with transient error, second succeeds
|
||||
mock_client.get.side_effect = [
|
||||
httpx.RequestError("transient"),
|
||||
ok_poll_resp,
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result == {"data": "ok"}
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_returns_404_for_expired_job(self):
|
||||
"""Test that 404 during polling returns job_not_found error."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-expired"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.side_effect = httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "job_not_found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_retries_on_transient_http_status(self):
|
||||
"""Test that transient HTTP status codes (429, 503, etc.) are retried."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-transient"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_429_response = MagicMock()
|
||||
mock_429_response.status_code = 429
|
||||
|
||||
ok_poll_resp = MagicMock()
|
||||
ok_poll_resp.json.return_value = {
|
||||
"job_id": "job-transient",
|
||||
"status": "completed",
|
||||
"result": {"data": "recovered"},
|
||||
}
|
||||
ok_poll_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.side_effect = [
|
||||
httpx.HTTPStatusError(
|
||||
"Too Many Requests", request=MagicMock(), response=mock_429_response
|
||||
),
|
||||
ok_poll_resp,
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result == {"data": "recovered"}
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_does_not_retry_non_transient_http_status(self):
|
||||
"""Test that non-transient HTTP status codes (e.g. 500) fail immediately."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-500"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_500_response = MagicMock()
|
||||
mock_500_response.status_code = 500
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.side_effect = httpx.HTTPStatusError(
|
||||
"Internal Server Error", request=MagicMock(), response=mock_500_response
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "http_error"
|
||||
assert mock_client.get.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_timeout(self):
|
||||
"""Test that polling times out after MAX_POLL_TIME_SECONDS."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-slow"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
running_resp = MagicMock()
|
||||
running_resp.json.return_value = {"job_id": "job-slow", "status": "running"}
|
||||
running_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.return_value = running_resp
|
||||
|
||||
# Simulate time passing: first call returns 0.0 (start), then jumps past limit
|
||||
monotonic_values = iter([0.0, 0.0, 100.0])
|
||||
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch.object(service, "MAX_POLL_TIME_SECONDS", 50.0),
|
||||
patch.object(service, "POLL_INTERVAL_SECONDS", 0.01),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
patch("backend.copilot.tools.agent_generator.service.time") as mock_time,
|
||||
):
|
||||
mock_time.monotonic.side_effect = monotonic_values
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "timeout"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_gives_up_after_consecutive_transient_errors(self):
|
||||
"""Test that polling gives up after MAX_CONSECUTIVE_POLL_ERRORS."""
|
||||
submit_resp = MagicMock()
|
||||
submit_resp.json.return_value = {"job_id": "job-flaky"}
|
||||
submit_resp.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = submit_resp
|
||||
mock_client.get.side_effect = httpx.RequestError("network down")
|
||||
|
||||
# Ensure monotonic always returns 0 so timeout doesn't kick in
|
||||
with (
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
patch.object(service, "MAX_POLL_TIME_SECONDS", 9999.0),
|
||||
patch.object(service, "POLL_INTERVAL_SECONDS", 0.01),
|
||||
patch("asyncio.sleep", new_callable=AsyncMock),
|
||||
patch("backend.copilot.tools.agent_generator.service.time") as mock_time,
|
||||
):
|
||||
mock_time.monotonic.return_value = 0.0
|
||||
result = await service._submit_and_poll("/api/test", {})
|
||||
|
||||
assert result["type"] == "error"
|
||||
assert result["error_type"] == "poll_error"
|
||||
assert mock_client.get.call_count == service.MAX_CONSECUTIVE_POLL_ERRORS
|
||||
|
||||
|
||||
class TestDecomposeGoalExternal:
|
||||
"""Test decompose_goal_external function."""
|
||||
|
||||
@@ -60,40 +346,37 @@ class TestDecomposeGoalExternal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_instructions(self):
|
||||
"""Test successful decomposition returning instructions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1", "Step 2"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1", "Step 2"],
|
||||
}
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]}
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/decompose-description", json={"description": "Build a chatbot"}
|
||||
mock_poll.assert_called_once_with(
|
||||
"/api/decompose-description",
|
||||
{"description": "Build a chatbot"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_clarifying_questions(self):
|
||||
"""Test decomposition returning clarifying questions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What platform?", "What language?"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What platform?", "What language?"],
|
||||
}
|
||||
result = await service.decompose_goal_external("Build something")
|
||||
|
||||
assert result == {
|
||||
@@ -104,18 +387,13 @@ class TestDecomposeGoalExternal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_with_context(self):
|
||||
"""Test decomposition with additional context enriched into description."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"type": "instructions", "steps": ["Step 1"]}
|
||||
await service.decompose_goal_external(
|
||||
"Build a chatbot", context="Use Python"
|
||||
)
|
||||
@@ -123,27 +401,25 @@ class TestDecomposeGoalExternal:
|
||||
expected_description = (
|
||||
"Build a chatbot\n\nAdditional context from user:\nUse Python"
|
||||
)
|
||||
mock_client.post.assert_called_once_with(
|
||||
mock_poll.assert_called_once_with(
|
||||
"/api/decompose-description",
|
||||
json={"description": expected_description},
|
||||
{"description": expected_description},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_returns_unachievable_goal(self):
|
||||
"""Test decomposition returning unachievable goal response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "unachievable_goal",
|
||||
"reason": "Cannot do X",
|
||||
"suggested_goal": "Try Y instead",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "unachievable_goal",
|
||||
"reason": "Cannot do X",
|
||||
"suggested_goal": "Try Y instead",
|
||||
}
|
||||
result = await service.decompose_goal_external("Do something impossible")
|
||||
|
||||
assert result == {
|
||||
@@ -153,58 +429,40 @@ class TestDecomposeGoalExternal:
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_http_error(self):
|
||||
"""Test decomposition handles HTTP errors gracefully."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.HTTPStatusError(
|
||||
"Server error", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
async def test_decompose_goal_handles_poll_error(self):
|
||||
"""Test that errors from _submit_and_poll are passed through."""
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "error",
|
||||
"error": "HTTP error calling Agent Generator: Server error",
|
||||
"error_type": "http_error",
|
||||
}
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "http_error"
|
||||
assert "Server error" in result.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_request_error(self):
|
||||
"""Test decomposition handles request errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
async def test_decompose_goal_handles_unexpected_exception(self):
|
||||
"""Test that unexpected exceptions are caught and returned as errors."""
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.side_effect = RuntimeError("unexpected")
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "connection_error"
|
||||
assert "Connection failed" in result.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_service_error(self):
|
||||
"""Test decomposition handles service returning error."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": False,
|
||||
"error": "Internal error",
|
||||
"error_type": "internal_error",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error") == "Internal error"
|
||||
assert result.get("error_type") == "internal_error"
|
||||
assert result.get("error_type") == "unexpected_error"
|
||||
|
||||
|
||||
class TestGenerateAgentExternal:
|
||||
@@ -223,39 +481,59 @@ class TestGenerateAgentExternal:
|
||||
"nodes": [],
|
||||
"links": [],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": agent_json,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"success": True, "agent_json": agent_json}
|
||||
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
result = await service.generate_agent_external(instructions)
|
||||
|
||||
assert result == agent_json
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/generate-agent", json={"instructions": instructions}
|
||||
mock_poll.assert_called_once_with(
|
||||
"/api/generate-agent",
|
||||
{"instructions": instructions},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_handles_error(self):
|
||||
"""Test agent generation handles errors gracefully."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "error",
|
||||
"error": "Connection failed",
|
||||
"error_type": "connection_error",
|
||||
}
|
||||
result = await service.generate_agent_external({"steps": []})
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "connection_error"
|
||||
assert "Connection failed" in result.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_missing_agent_json(self):
|
||||
"""Test that missing agent_json in result returns an error."""
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"success": True}
|
||||
result = await service.generate_agent_external({"steps": ["Step 1"]})
|
||||
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "invalid_response"
|
||||
|
||||
|
||||
class TestGenerateAgentPatchExternal:
|
||||
@@ -274,27 +552,24 @@ class TestGenerateAgentPatchExternal:
|
||||
"nodes": [{"id": "1", "block_id": "test"}],
|
||||
"links": [],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": updated_agent,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"success": True, "agent_json": updated_agent}
|
||||
|
||||
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
|
||||
result = await service.generate_agent_patch_external(
|
||||
"Add a new node", current_agent
|
||||
)
|
||||
|
||||
assert result == updated_agent
|
||||
mock_client.post.assert_called_once_with(
|
||||
mock_poll.assert_called_once_with(
|
||||
"/api/update-agent",
|
||||
json={
|
||||
{
|
||||
"update_request": "Add a new node",
|
||||
"current_agent_json": current_agent,
|
||||
},
|
||||
@@ -303,18 +578,16 @@ class TestGenerateAgentPatchExternal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_patch_returns_clarifying_questions(self):
|
||||
"""Test patch generation returning clarifying questions."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What type of node?"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": ["What type of node?"],
|
||||
}
|
||||
result = await service.generate_agent_patch_external(
|
||||
"Add something", {"nodes": []}
|
||||
)
|
||||
@@ -355,9 +628,12 @@ class TestHealthCheck:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
with (
|
||||
patch.object(service, "is_external_service_configured", return_value=True),
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is True
|
||||
mock_client.get.assert_called_once_with("/health")
|
||||
@@ -375,9 +651,12 @@ class TestHealthCheck:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
with (
|
||||
patch.object(service, "is_external_service_configured", return_value=True),
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -387,9 +666,12 @@ class TestHealthCheck:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.health_check()
|
||||
with (
|
||||
patch.object(service, "is_external_service_configured", return_value=True),
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -419,7 +701,10 @@ class TestGetBlocksExternal:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.get_blocks_external()
|
||||
|
||||
assert result == blocks
|
||||
@@ -431,7 +716,10 @@ class TestGetBlocksExternal:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(service, "_get_client", return_value=mock_client),
|
||||
):
|
||||
result = await service.get_blocks_external()
|
||||
|
||||
assert result is None
|
||||
@@ -459,26 +747,22 @@ class TestLibraryAgentsPassthrough:
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"type": "instructions", "steps": ["Step 1"]}
|
||||
await service.decompose_goal_external(
|
||||
"Send an email",
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
call_args = mock_poll.call_args
|
||||
payload = call_args[0][1]
|
||||
assert payload["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_passes_library_agents(self):
|
||||
@@ -494,25 +778,24 @@ class TestLibraryAgentsPassthrough:
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": {"name": "Test Agent", "nodes": []},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"agent_json": {"name": "Test Agent", "nodes": []},
|
||||
}
|
||||
await service.generate_agent_external(
|
||||
{"steps": ["Step 1"]},
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
call_args = mock_poll.call_args
|
||||
payload = call_args[0][1]
|
||||
assert payload["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_patch_passes_library_agents(self):
|
||||
@@ -528,17 +811,15 @@ class TestLibraryAgentsPassthrough:
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": {"name": "Updated Agent", "nodes": []},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {
|
||||
"agent_json": {"name": "Updated Agent", "nodes": []},
|
||||
}
|
||||
await service.generate_agent_patch_external(
|
||||
"Add error handling",
|
||||
{"name": "Original Agent", "nodes": []},
|
||||
@@ -546,29 +827,26 @@ class TestLibraryAgentsPassthrough:
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
call_args = mock_poll.call_args
|
||||
payload = call_args[0][1]
|
||||
assert payload["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_without_library_agents(self):
|
||||
"""Test that decompose goal works without library_agents."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
with (
|
||||
patch.object(service, "_is_dummy_mode", return_value=False),
|
||||
patch.object(
|
||||
service, "_submit_and_poll", new_callable=AsyncMock
|
||||
) as mock_poll,
|
||||
):
|
||||
mock_poll.return_value = {"type": "instructions", "steps": ["Step 1"]}
|
||||
await service.decompose_goal_external("Build a workflow")
|
||||
|
||||
# Verify library_agents was NOT passed when not provided
|
||||
call_args = mock_client.post.call_args
|
||||
assert "library_agents" not in call_args[1]["json"]
|
||||
call_args = mock_poll.call_args
|
||||
payload = call_args[0][1]
|
||||
assert "library_agents" not in payload
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -243,6 +243,13 @@ export function useCopilotPage() {
|
||||
// next hydration fetches fresh messages from the backend. Without this,
|
||||
// staleTime: Infinity means the cache keeps the pre-stream data forever,
|
||||
// and any messages added during streaming are lost on remount/navigation.
|
||||
// Track status transitions for cache invalidation and auto-reconnect.
|
||||
// Auto-reconnect: GCP's L7 load balancer kills SSE connections at ~5 min.
|
||||
// When that happens the AI SDK goes "streaming" → "error". If the backend
|
||||
// executor is still running (hasActiveStream), we call resumeStream() to
|
||||
// reconnect via GET and replay from Redis.
|
||||
const MAX_RECONNECT_ATTEMPTS = 3;
|
||||
const reconnectAttemptsRef = useRef(0);
|
||||
const prevStatusRef = useRef(status);
|
||||
useEffect(() => {
|
||||
const prev = prevStatusRef.current;
|
||||
@@ -250,12 +257,44 @@ export function useCopilotPage() {
|
||||
|
||||
const wasActive = prev === "streaming" || prev === "submitted";
|
||||
const isIdle = status === "ready" || status === "error";
|
||||
|
||||
// Invalidate session cache when stream ends so hydration fetches fresh data
|
||||
if (wasActive && isIdle && sessionId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
}
|
||||
}, [status, sessionId, queryClient]);
|
||||
|
||||
// Auto-reconnect on mid-stream SSE drop
|
||||
if (
|
||||
prev === "streaming" &&
|
||||
status === "error" &&
|
||||
sessionId &&
|
||||
hasActiveStream
|
||||
) {
|
||||
if (reconnectAttemptsRef.current < MAX_RECONNECT_ATTEMPTS) {
|
||||
reconnectAttemptsRef.current += 1;
|
||||
const attempt = reconnectAttemptsRef.current;
|
||||
console.info(
|
||||
`[copilot] SSE dropped mid-stream, reconnecting (attempt ${attempt}/${MAX_RECONNECT_ATTEMPTS})...`,
|
||||
);
|
||||
const timer = setTimeout(() => resumeStream(), 1_000);
|
||||
return () => clearTimeout(timer);
|
||||
} else {
|
||||
toast({
|
||||
title: "Connection lost",
|
||||
description:
|
||||
"Could not reconnect to the stream. Please refresh the page.",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Reset reconnect counter when stream completes normally or resumes
|
||||
if (status === "ready" || status === "streaming") {
|
||||
reconnectAttemptsRef.current = 0;
|
||||
}
|
||||
}, [status, sessionId, hasActiveStream, queryClient, resumeStream]);
|
||||
|
||||
// Resume an active stream AFTER hydration completes.
|
||||
// IMPORTANT: Only runs when page loads with existing active stream (reconnection).
|
||||
|
||||
Reference in New Issue
Block a user