Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-01-31 09:58:19 -05:00

Compare commits
1 commit
hotfix/cop ... hotfix/ope

| Author | SHA1 | Date |
|---|---|---|
|  | 8572b12d0a |  |
@@ -1883,127 +1883,84 @@ async def _generate_llm_continuation(
     This is called by background tasks to continue the conversation
     after a tool result is saved. The response is saved to the database
     so users see it when they refresh or poll.
-
-    Includes retry logic with exponential backoff for transient API errors.
     """
-    retry_count = 0
-    max_retries = MAX_RETRIES
-    last_error: Exception | None = None
-
-    while retry_count <= max_retries:
-        try:
-            # Load fresh session from DB (bypass cache to get the updated tool result)
-            await invalidate_session_cache(session_id)
-            session = await get_chat_session(session_id, user_id)
-            if not session:
-                logger.error(f"Session {session_id} not found for LLM continuation")
-                return
-
-            # Build system prompt
-            system_prompt, _ = await _build_system_prompt(user_id)
-
-            # Build messages in OpenAI format
-            messages = session.to_openai_messages()
-            if system_prompt:
-                from openai.types.chat import ChatCompletionSystemMessageParam
-
-                system_message = ChatCompletionSystemMessageParam(
-                    role="system",
-                    content=system_prompt,
-                )
-                messages = [system_message] + messages
-
-            # Build extra_body for tracing
-            extra_body: dict[str, Any] = {
-                "posthogProperties": {
-                    "environment": settings.config.app_env.value,
-                },
-            }
-            if user_id:
-                extra_body["user"] = user_id[:128]
-                extra_body["posthogDistinctId"] = user_id
-            if session_id:
-                extra_body["session_id"] = session_id[:128]
-
-            from typing import cast
-
-            from openai.types.chat import ChatCompletionMessageParam
-
-            # Include tools with tool_choice="none" to allow the provider to validate
-            # tool interactions in the message history without allowing new tool calls.
-            # Some providers (especially Anthropic via OpenRouter) require the tools
-            # schema to be present when the conversation contains tool_use/tool_result.
-            response = await client.chat.completions.create(
-                model=config.model,
-                messages=cast(list[ChatCompletionMessageParam], messages),
-                tools=tools,
-                tool_choice="none",
-                extra_body=extra_body,
-            )
-
-            if response.choices and response.choices[0].message.content:
-                assistant_content = response.choices[0].message.content
-
-                # Reload session from DB to avoid race condition with user messages
-                # that may have been sent while we were generating the LLM response.
-                # Invalidate cache first to ensure we get fresh data from the database.
-                await invalidate_session_cache(session_id)
-                fresh_session = await get_chat_session(session_id, user_id)
-                if not fresh_session:
-                    logger.error(
-                        f"Session {session_id} disappeared during LLM continuation"
-                    )
-                    return
-
-                # Save assistant message to database
-                assistant_message = ChatMessage(
-                    role="assistant",
-                    content=assistant_content,
-                )
-                fresh_session.messages.append(assistant_message)
-
-                # Save to database (not cache) to persist the response
-                await upsert_chat_session(fresh_session)
-
-                # Invalidate cache so next poll/refresh gets fresh data
-                await invalidate_session_cache(session_id)
-
-                logger.info(
-                    f"Generated LLM continuation for session {session_id}, "
-                    f"response length: {len(assistant_content)}"
-                )
-                return  # Success - exit the retry loop
-            else:
-                logger.warning(
-                    f"LLM continuation returned empty response for {session_id}"
-                )
-                return  # Empty response is not retryable
-
-        except Exception as e:
-            last_error = e
-            if _is_retryable_error(e) and retry_count < max_retries:
-                retry_count += 1
-                delay = min(
-                    BASE_DELAY_SECONDS * (2 ** (retry_count - 1)),
-                    MAX_DELAY_SECONDS,
-                )
-                logger.warning(
-                    f"Retryable error in LLM continuation for {session_id}: {e!s}. "
-                    f"Retrying in {delay:.1f}s (attempt {retry_count}/{max_retries})"
-                )
-                await asyncio.sleep(delay)
-                continue
-            else:
-                # Non-retryable error or max retries exceeded
-                logger.error(
-                    f"Failed to generate LLM continuation for {session_id}: {e!s}",
-                    exc_info=True,
-                )
-                break
-
-    # If we get here, all retries failed
-    if last_error:
-        logger.error(
-            f"LLM continuation failed after {retry_count} retries for {session_id}. "
-            f"Last error: {last_error!s}"
-        )
+    try:
+        # Load fresh session from DB (bypass cache to get the updated tool result)
+        await invalidate_session_cache(session_id)
+        session = await get_chat_session(session_id, user_id)
+        if not session:
+            logger.error(f"Session {session_id} not found for LLM continuation")
+            return
+
+        # Build system prompt
+        system_prompt, _ = await _build_system_prompt(user_id)
+
+        # Build messages in OpenAI format
+        messages = session.to_openai_messages()
+        if system_prompt:
+            from openai.types.chat import ChatCompletionSystemMessageParam
+
+            system_message = ChatCompletionSystemMessageParam(
+                role="system",
+                content=system_prompt,
+            )
+            messages = [system_message] + messages
+
+        # Build extra_body for tracing
+        extra_body: dict[str, Any] = {
+            "posthogProperties": {
+                "environment": settings.config.app_env.value,
+            },
+        }
+        if user_id:
+            extra_body["user"] = user_id[:128]
+            extra_body["posthogDistinctId"] = user_id
+        if session_id:
+            extra_body["session_id"] = session_id[:128]
+
+        # Make non-streaming LLM call (no tools - just text response)
+        from typing import cast
+
+        from openai.types.chat import ChatCompletionMessageParam
+
+        # No tools parameter = text-only response (no tool calls)
+        response = await client.chat.completions.create(
+            model=config.model,
+            messages=cast(list[ChatCompletionMessageParam], messages),
+            extra_body=extra_body,
+        )
+
+        if response.choices and response.choices[0].message.content:
+            assistant_content = response.choices[0].message.content
+
+            # Reload session from DB to avoid race condition with user messages
+            # that may have been sent while we were generating the LLM response
+            fresh_session = await get_chat_session(session_id, user_id)
+            if not fresh_session:
+                logger.error(
+                    f"Session {session_id} disappeared during LLM continuation"
+                )
+                return
+
+            # Save assistant message to database
+            assistant_message = ChatMessage(
+                role="assistant",
+                content=assistant_content,
+            )
+            fresh_session.messages.append(assistant_message)
+
+            # Save to database (not cache) to persist the response
+            await upsert_chat_session(fresh_session)
+
+            # Invalidate cache so next poll/refresh gets fresh data
+            await invalidate_session_cache(session_id)
+
+            logger.info(
+                f"Generated LLM continuation for session {session_id}, "
+                f"response length: {len(assistant_content)}"
+            )
+        else:
+            logger.warning(f"LLM continuation returned empty response for {session_id}")
+
+    except Exception as e:
+        logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
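The retry logic removed above is the standard capped-exponential-backoff pattern. Below is a self-contained sketch of that pattern; `MAX_RETRIES`, `BASE_DELAY_SECONDS`, `MAX_DELAY_SECONDS`, and `_is_retryable_error` are module-level helpers referenced by the old code, so the constant values and the retryability heuristic shown here are illustrative assumptions, not the repository's actual definitions.

```python
import asyncio

# Illustrative stand-ins; the repository's real constants may differ.
MAX_RETRIES = 3
BASE_DELAY_SECONDS = 1.0
MAX_DELAY_SECONDS = 30.0


def _is_retryable_error(e: Exception) -> bool:
    # Placeholder heuristic: the real helper classifies transient API errors
    # (rate limits, timeouts, connection resets).
    return isinstance(e, (TimeoutError, ConnectionError))


async def call_with_backoff(func):
    """Retry an async callable with capped exponential backoff."""
    retry_count = 0
    while True:
        try:
            return await func()
        except Exception as e:
            if not _is_retryable_error(e) or retry_count >= MAX_RETRIES:
                raise
            retry_count += 1
            # Delays double each attempt: 1s, 2s, 4s, ... capped at 30s.
            delay = min(
                BASE_DELAY_SECONDS * (2 ** (retry_count - 1)), MAX_DELAY_SECONDS
            )
            await asyncio.sleep(delay)
```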
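The other behavioral change in this hunk is the `tools` parameter. The removed call declared the tool schema but set `tool_choice="none"`, which (per its own comment) satisfies providers that validate `tool_use`/`tool_result` pairs in the history while still forbidding new tool calls; the new call omits `tools` entirely. A minimal sketch of the two call shapes against the OpenAI Python SDK, with a made-up tool schema:

```python
from openai import AsyncOpenAI

client = AsyncOpenAI()
# Hypothetical tool schema, for illustration only.
tools = [
    {
        "type": "function",
        "function": {
            "name": "example_tool",
            "parameters": {"type": "object", "properties": {}},
        },
    }
]


async def continuation(messages: list[dict]) -> str:
    # Old shape: tools are declared, but the model may not call them.
    # resp = await client.chat.completions.create(
    #     model="gpt-4o", messages=messages, tools=tools, tool_choice="none"
    # )
    # New shape: no tools parameter at all, plain text completion.
    resp = await client.chat.completions.create(model="gpt-4o", messages=messages)
    return resp.choices[0].message.content or ""
```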
@@ -1,13 +1,11 @@
 """Shared agent search functionality for find_agent and find_library_agent tools."""
 
-import asyncio
 import logging
-from typing import Literal
+from typing import Any, Literal
 
 from backend.api.features.library import db as library_db
 from backend.api.features.store import db as store_db
 from backend.data import graph as graph_db
-from backend.data.graph import GraphModel
 from backend.util.exceptions import DatabaseError, NotFoundError
 
 from .models import (
@@ -17,13 +15,50 @@ from .models import (
     NoResultsResponse,
     ToolResponseBase,
 )
-from .utils import fetch_graph_from_store_slug
 
 logger = logging.getLogger(__name__)
 
 SearchSource = Literal["marketplace", "library"]
 
 
+async def _fetch_input_schema_for_store_agent(
+    creator: str, slug: str
+) -> dict[str, Any] | None:
+    """Fetch input schema for a marketplace agent."""
+    try:
+        store_agent = await store_db.get_store_agent_details(creator, slug)
+        graph_meta = await store_db.get_available_graph(
+            store_agent.store_listing_version_id
+        )
+        graph = await graph_db.get_graph(
+            graph_id=graph_meta.id,
+            version=graph_meta.version,
+            user_id=None,
+            include_subgraphs=False,
+        )
+        return graph.input_schema if graph else None
+    except Exception as e:
+        logger.warning(f"Failed to fetch input schema for {creator}/{slug}: {e}")
+        return None
+
+
+async def _fetch_input_schema_for_library_agent(
+    graph_id: str, user_id: str
+) -> dict[str, Any] | None:
+    """Fetch input schema for a library agent."""
+    try:
+        graph = await graph_db.get_graph(
+            graph_id=graph_id,
+            version=None,  # Get latest version
+            user_id=user_id,
+            include_subgraphs=False,
+        )
+        return graph.input_schema if graph else None
+    except Exception as e:
+        logger.warning(f"Failed to fetch input schema for graph {graph_id}: {e}")
+        return None
+
+
 async def search_agents(
     query: str,
     source: SearchSource,
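Both added helpers are deliberately fail-soft: any lookup error is logged and collapses to `None`, so one broken listing degrades a single search result (its `inputs` comes back empty) instead of failing the whole search. A hedged usage sketch, assuming it runs inside the module shown above; the creator and slug are made-up identifiers:

```python
import asyncio


async def demo() -> None:
    # "example-creator" / "example-agent" are illustrative identifiers.
    schema = await _fetch_input_schema_for_store_agent(
        "example-creator", "example-agent"
    )
    if schema is None:
        print("schema unavailable; the search result will carry inputs=None")
    else:
        print("input fields:", sorted(schema.get("properties", {})))


asyncio.run(demo())
```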
@@ -58,28 +93,11 @@ async def search_agents(
     if source == "marketplace":
         logger.info(f"Searching marketplace for: {query}")
         results = await store_db.get_store_agents(search_query=query, page_size=5)
-
-        # Fetch all graphs in parallel for better performance
-        async def fetch_marketplace_graph(
-            creator: str, slug: str
-        ) -> GraphModel | None:
-            try:
-                graph, _ = await fetch_graph_from_store_slug(creator, slug)
-                return graph
-            except Exception as e:
-                logger.warning(
-                    f"Failed to fetch input schema for {creator}/{slug}: {e}"
-                )
-                return None
-
-        graphs = await asyncio.gather(
-            *(
-                fetch_marketplace_graph(agent.creator, agent.slug)
-                for agent in results.agents
-            )
-        )
-
-        for agent, graph in zip(results.agents, graphs):
+        for agent in results.agents:
+            # Fetch input schema for the agent
+            input_schema = await _fetch_input_schema_for_store_agent(
+                agent.creator, agent.slug
+            )
             agents.append(
                 AgentInfo(
                     id=f"{agent.creator}/{agent.slug}",
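Note the concurrency trade this hunk makes: the removed code fetched all graphs at once with `asyncio.gather`, while the replacement awaits one schema per agent, so latency grows linearly with the result count (capped at `page_size=5` here). A self-contained sketch of the two shapes:

```python
import asyncio


async def fetch(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for a graph/DB lookup
    return i


async def main() -> None:
    # Sequential, as in the new code: total latency ~ N * per-call latency.
    sequential = [await fetch(i) for i in range(5)]
    # Concurrent, as in the removed code: total latency ~ one call.
    concurrent = await asyncio.gather(*(fetch(i) for i in range(5)))
    assert sequential == list(concurrent)


asyncio.run(main())
```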
@@ -92,7 +110,7 @@ async def search_agents(
                     rating=agent.rating,
                     runs=agent.runs,
                     is_featured=False,
-                    inputs=graph.input_schema if graph else None,
+                    inputs=input_schema,
                 )
             )
     else:  # library
@@ -102,32 +120,11 @@ async def search_agents(
             search_term=query,
             page_size=10,
         )
-
-        # Fetch all graphs in parallel for better performance
-        # (list_library_agents doesn't include nodes for performance)
-        async def fetch_library_graph(
-            graph_id: str, graph_version: int
-        ) -> GraphModel | None:
-            try:
-                return await graph_db.get_graph(
-                    graph_id=graph_id,
-                    version=graph_version,
-                    user_id=user_id,
-                )
-            except Exception as e:
-                logger.warning(
-                    f"Failed to fetch input schema for graph {graph_id}: {e}"
-                )
-                return None
-
-        graphs = await asyncio.gather(
-            *(
-                fetch_library_graph(agent.graph_id, agent.graph_version)
-                for agent in results.agents
-            )
-        )
-
-        for agent, graph in zip(results.agents, graphs):
+        for agent in results.agents:
+            # Fetch input schema for the agent
+            input_schema = await _fetch_input_schema_for_library_agent(
+                agent.graph_id, user_id  # type: ignore[arg-type]
+            )
             agents.append(
                 AgentInfo(
                     id=agent.id,
@@ -141,7 +138,7 @@ async def search_agents(
                     has_external_trigger=agent.has_external_trigger,
                     new_output=agent.new_output,
                     graph_id=agent.graph_id,
-                    inputs=graph.input_schema if graph else None,
+                    inputs=input_schema,
                 )
             )
     logger.info(f"Found {len(agents)} agents in {source}")
@@ -32,8 +32,6 @@ class ResponseType(str, Enum):
     OPERATION_STARTED = "operation_started"
     OPERATION_PENDING = "operation_pending"
     OPERATION_IN_PROGRESS = "operation_in_progress"
-    # Input validation
-    INPUT_VALIDATION_ERROR = "input_validation_error"
 
 
 # Base response model
@@ -194,20 +192,6 @@ class ErrorResponse(ToolResponseBase):
     details: dict[str, Any] | None = None
 
 
-class InputValidationErrorResponse(ToolResponseBase):
-    """Response when run_agent receives unknown input fields."""
-
-    type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
-    unrecognized_fields: list[str] = Field(
-        description="List of input field names that were not recognized"
-    )
-    inputs: dict[str, Any] = Field(
-        description="The agent's valid input schema for reference"
-    )
-    graph_id: str | None = None
-    graph_version: int | None = None
-
-
 # Agent output models
 class ExecutionOutputInfo(BaseModel):
     """Summary of a single execution's outputs."""
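For reference, the deleted class is a plain pydantic model whose `Field(description=...)` metadata ends up in the JSON the LLM sees. A minimal sketch of the same shape; `ToolResponseBase` lives elsewhere in this module, so a stand-in base with assumed fields is used here:

```python
from typing import Any

from pydantic import BaseModel, Field


class ToolResponseBaseStandIn(BaseModel):
    # Stand-in for the real base class; its actual fields may differ.
    message: str
    session_id: str | None = None


class InputValidationErrorSketch(ToolResponseBaseStandIn):
    type: str = "input_validation_error"
    unrecognized_fields: list[str] = Field(
        description="List of input field names that were not recognized"
    )
    inputs: dict[str, Any] = Field(
        description="The agent's valid input schema for reference"
    )


print(
    InputValidationErrorSketch(
        message="Unknown input field(s) provided: foo.",
        unrecognized_fields=["foo"],
        inputs={"properties": {}},
    ).model_dump_json()
)
```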
@@ -30,7 +30,6 @@ from .models import (
     ErrorResponse,
     ExecutionOptions,
     ExecutionStartedResponse,
-    InputValidationErrorResponse,
     SetupInfo,
     SetupRequirementsResponse,
     ToolResponseBase,
@@ -274,22 +273,6 @@ class RunAgentTool(BaseTool):
         input_properties = graph.input_schema.get("properties", {})
         required_fields = set(graph.input_schema.get("required", []))
         provided_inputs = set(params.inputs.keys())
-        valid_fields = set(input_properties.keys())
-
-        # Check for unknown input fields
-        unrecognized_fields = provided_inputs - valid_fields
-        if unrecognized_fields:
-            return InputValidationErrorResponse(
-                message=(
-                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
-                    f"Agent was not executed. Please use the correct field names from the schema."
-                ),
-                session_id=session_id,
-                unrecognized_fields=sorted(unrecognized_fields),
-                inputs=graph.input_schema,
-                graph_id=graph.id,
-                graph_version=graph.version,
-            )
 
         # If agent has inputs but none were provided AND use_defaults is not set,
         # always show what's available first so user can decide
@@ -327,6 +310,26 @@ class RunAgentTool(BaseTool):
                 graph_version=graph.version,
             )
 
+        # Check for unknown input fields - reject to prevent silent failures
+        valid_fields = set(input_properties.keys())
+        unknown_fields = provided_inputs - valid_fields
+        if unknown_fields:
+            credentials = extract_credentials_from_schema(
+                graph.credentials_input_schema
+            )
+            return AgentDetailsResponse(
+                message=(
+                    f"Unknown input field(s) provided: {', '.join(sorted(unknown_fields))}. "
+                    f"Agent was not executed. "
+                    f"Valid input fields are: {', '.join(sorted(valid_fields)) or 'none'}."
+                ),
+                session_id=session_id,
+                agent=self._build_agent_details(graph, credentials),
+                user_authenticated=True,
+                graph_id=graph.id,
+                graph_version=graph.version,
+            )
+
         # Step 4: Execute or Schedule
         if is_schedule:
             return await self._schedule_agent(
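The check itself is plain set arithmetic over the JSON-schema `properties` keys; a self-contained example with a made-up schema:

```python
input_schema = {
    "properties": {"topic": {"type": "string"}, "depth": {"type": "integer"}},
    "required": ["topic"],
}
provided_inputs = {"topic", "dept"}  # note the typo "dept"

valid_fields = set(input_schema["properties"].keys())
unknown_fields = provided_inputs - valid_fields
assert unknown_fields == {"dept"}
print(
    f"Unknown input field(s) provided: {', '.join(sorted(unknown_fields))}. "
    f"Agent was not executed. "
    f"Valid input fields are: {', '.join(sorted(valid_fields)) or 'none'}."
)
```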
@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
     # Should return error about missing schedule_name
     assert result_data.get("type") == "error"
     assert "schedule_name" in result_data["message"].lower()
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
-    """Test that run_agent returns input_validation_error for unknown input fields."""
-    user = setup_test_data["user"]
-    store_submission = setup_test_data["store_submission"]
-
-    tool = RunAgentTool()
-    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
-    session = make_session(user_id=user.id)
-
-    # Execute with unknown input field names
-    response = await tool.execute(
-        user_id=user.id,
-        session_id=str(uuid.uuid4()),
-        tool_call_id=str(uuid.uuid4()),
-        username_agent_slug=agent_marketplace_id,
-        inputs={
-            "unknown_field": "some value",
-            "another_unknown": "another value",
-        },
-        session=session,
-    )
-
-    assert response is not None
-    assert hasattr(response, "output")
-    assert isinstance(response.output, str)
-    result_data = orjson.loads(response.output)
-
-    # Should return input_validation_error type with unrecognized fields
-    assert result_data.get("type") == "input_validation_error"
-    assert "unrecognized_fields" in result_data
-    assert set(result_data["unrecognized_fields"]) == {
-        "another_unknown",
-        "unknown_field",
-    }
-    assert "inputs" in result_data  # Contains the valid schema
-    assert "Agent was not executed" in result_data["message"]
@@ -4,8 +4,6 @@ import logging
 from collections import defaultdict
 from typing import Any
 
-from pydantic_core import PydanticUndefined
-
 from backend.api.features.chat.model import ChatSession
 from backend.data.block import get_block
 from backend.data.execution import ExecutionContext
@@ -75,22 +73,15 @@ class RunBlockTool(BaseTool):
         self,
         user_id: str,
         block: Any,
-        input_data: dict[str, Any] | None = None,
     ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
         """
         Check if user has required credentials for a block.
 
-        Args:
-            user_id: User ID
-            block: Block to check credentials for
-            input_data: Input data for the block (used to determine provider via discriminator)
-
         Returns:
             tuple[matched_credentials, missing_credentials]
         """
         matched_credentials: dict[str, CredentialsMetaInput] = {}
         missing_credentials: list[CredentialsMetaInput] = []
-        input_data = input_data or {}
 
         # Get credential field info from block's input schema
         credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -103,33 +94,14 @@ class RunBlockTool(BaseTool):
         available_creds = await creds_manager.store.get_all_creds(user_id)
 
         for field_name, field_info in credentials_fields_info.items():
-            effective_field_info = field_info
-            if field_info.discriminator and field_info.discriminator_mapping:
-                # Get discriminator from input, falling back to schema default
-                discriminator_value = input_data.get(field_info.discriminator)
-                if discriminator_value is None:
-                    field = block.input_schema.model_fields.get(
-                        field_info.discriminator
-                    )
-                    if field and field.default is not PydanticUndefined:
-                        discriminator_value = field.default
-                if (
-                    discriminator_value
-                    and discriminator_value in field_info.discriminator_mapping
-                ):
-                    effective_field_info = field_info.discriminate(discriminator_value)
-                    logger.debug(
-                        f"Discriminated provider for {field_name}: "
-                        f"{discriminator_value} -> {effective_field_info.provider}"
-                    )
             # field_info.provider is a frozenset of acceptable providers
             # field_info.supported_types is a frozenset of acceptable types
             matching_cred = next(
                 (
                     cred
                     for cred in available_creds
-                    if cred.provider in effective_field_info.provider
-                    and cred.type in effective_field_info.supported_types
+                    if cred.provider in field_info.provider
+                    and cred.type in field_info.supported_types
                 ),
                 None,
            )
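The deleted block implemented discriminated credential fields: when one field can accept several providers, a discriminator value read from the block's input (falling back to the field's schema default) narrows `field_info` to a single provider before matching. A toy model of that resolution, mirroring the removed logic; the class and mapping are invented for illustration:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FieldInfo:
    provider: frozenset[str]
    discriminator: str | None = None
    discriminator_mapping: dict[str, str] | None = None

    def discriminate(self, value: str) -> "FieldInfo":
        # Narrow to the single provider the discriminator selects.
        assert self.discriminator_mapping is not None
        return FieldInfo(provider=frozenset({self.discriminator_mapping[value]}))


field_info = FieldInfo(
    provider=frozenset({"openai", "anthropic"}),
    discriminator="model",
    discriminator_mapping={"gpt-4o": "openai", "claude-sonnet": "anthropic"},
)
input_data = {"model": "claude-sonnet"}

effective_field_info = field_info
value = input_data.get(field_info.discriminator)
if value and value in (field_info.discriminator_mapping or {}):
    effective_field_info = field_info.discriminate(value)
assert effective_field_info.provider == {"anthropic"}
```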
@@ -143,8 +115,8 @@ class RunBlockTool(BaseTool):
             )
         else:
             # Create a placeholder for the missing credential
-            provider = next(iter(effective_field_info.provider), "unknown")
-            cred_type = next(iter(effective_field_info.supported_types), "api_key")
+            provider = next(iter(field_info.provider), "unknown")
+            cred_type = next(iter(field_info.supported_types), "api_key")
             missing_credentials.append(
                 CredentialsMetaInput(
                     id=field_name,
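Matching itself is a first-match scan: `next()` over a generator returns the first stored credential whose provider and type fall in the field's accepted frozensets, with `None` as the default instead of raising `StopIteration`. A standalone sketch with made-up credential records:

```python
from types import SimpleNamespace

available_creds = [
    SimpleNamespace(provider="github", type="oauth2"),
    SimpleNamespace(provider="openai", type="api_key"),
]
accepted_providers = frozenset({"openai", "open_router"})
accepted_types = frozenset({"api_key"})

matching_cred = next(
    (
        cred
        for cred in available_creds
        if cred.provider in accepted_providers and cred.type in accepted_types
    ),
    None,  # default value instead of raising StopIteration
)
assert matching_cred is not None and matching_cred.provider == "openai"
```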
@@ -212,9 +184,10 @@ class RunBlockTool(BaseTool):
 
         logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
 
+        # Check credentials
         creds_manager = IntegrationCredentialsManager()
         matched_credentials, missing_credentials = await self._check_block_credentials(
-            user_id, block, input_data
+            user_id, block
         )
 
         if missing_credentials: