From 083cceca0fd85b3f24f30a8b942355e1ac188d88 Mon Sep 17 00:00:00 2001
From: Swifty
Date: Thu, 29 Jan 2026 18:02:21 +0100
Subject: [PATCH] Fix chat streaming edge cases: SSE heartbeats, resume by
 Redis Stream message ID, cross-pod delivery, and async (202) tool completion

---
 .../backend/api/features/chat/routes.py       |  47 ++-
 .../backend/api/features/chat/service.py      |  29 +-
 .../api/features/chat/stream_registry.py      | 284 ++++++++++++++----
 .../chat/tools/agent_generator/core.py        |  30 +-
 .../chat/tools/agent_generator/service.py     |  62 +++-
 .../api/features/chat/tools/create_agent.py   |  24 +-
 .../api/features/chat/tools/edit_agent.py     |  25 +-
 .../backend/api/features/chat/tools/models.py |  17 ++
 8 files changed, 426 insertions(+), 92 deletions(-)

diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py
index bc9611021f..3b98d7e542 100644
--- a/autogpt_platform/backend/backend/api/features/chat/routes.py
+++ b/autogpt_platform/backend/backend/api/features/chat/routes.py
@@ -15,7 +15,7 @@ from . import service as chat_service
 from . import stream_registry
 from .config import ChatConfig
 from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish
+from .response_model import StreamFinish, StreamHeartbeat
 
 config = ChatConfig()
 
@@ -385,7 +385,10 @@ async def stream_task(
     task_id: str,
     user_id: str | None = Depends(auth.get_user_id),
-    last_idx: int = Query(default=0, ge=0, description="Last message index received"),
+    last_message_id: str = Query(
+        default="0-0",
+        description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
+    ),
 ):
     """
     Reconnect to a long-running task's SSE stream.
 
@@ -397,10 +400,10 @@ async def stream_task(
     Args:
         task_id: The task ID from the operation_started response.
         user_id: Authenticated user ID for ownership validation.
-        last_idx: Last message index received (0 for full replay).
+        last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
 
     Returns:
-        StreamingResponse: SSE-formatted response chunks starting from last_idx.
+        StreamingResponse: SSE-formatted response chunks starting after last_message_id.
 
     Raises:
         NotFoundError: If task_id is not found or user doesn't have access.
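
[Editor's note] To make the reconnect contract above concrete, here is a minimal client-side sketch. It is illustrative only: the base URL and the use of httpx are assumptions, while the endpoint path, the `last_message_id` query parameter, and the "data: [DONE]" terminator come from the patch itself.

```python
import httpx  # any SSE-capable HTTP client works; httpx is an assumption

async def resume_task_stream(task_id: str, last_message_id: str = "0-0") -> None:
    """Reconnect to a task stream, replaying everything after last_message_id."""
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        async with client.stream(
            "GET",
            f"/chat/tasks/{task_id}/stream",
            params={"last_message_id": last_message_id},
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank keep-alive lines between events
                payload = line.removeprefix("data: ")
                if payload == "[DONE]":
                    break  # AI SDK protocol termination marker
                print(payload)  # parse the JSON chunk, track its stream ID, etc.
```
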
@@ -409,30 +412,42 @@ async def stream_task(
     subscriber_queue = await stream_registry.subscribe_to_task(
         task_id=task_id,
         user_id=user_id,
-        last_idx=last_idx,
+        last_message_id=last_message_id,
     )
 
     if subscriber_queue is None:
         raise NotFoundError(f"Task {task_id} not found or access denied.")
 
     async def event_generator() -> AsyncGenerator[str, None]:
+        import asyncio
+
         chunk_count = 0
+        heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
         try:
             while True:
-                # Wait for next chunk from the queue
-                chunk = await subscriber_queue.get()
-                chunk_count += 1
-                yield chunk.to_sse()
-
-                # Check for finish signal
-                if isinstance(chunk, StreamFinish):
-                    logger.info(
-                        f"Task stream completed for task {task_id}, "
-                        f"chunk_count={chunk_count}"
-                    )
-                    break
+                try:
+                    # Wait for next chunk with timeout for heartbeats
+                    chunk = await asyncio.wait_for(
+                        subscriber_queue.get(), timeout=heartbeat_interval
+                    )
+                    chunk_count += 1
+                    yield chunk.to_sse()
+
+                    # Check for finish signal
+                    if isinstance(chunk, StreamFinish):
+                        logger.info(
+                            f"Task stream completed for task {task_id}, "
+                            f"chunk_count={chunk_count}"
+                        )
+                        break
+                except asyncio.TimeoutError:
+                    # Send heartbeat to keep connection alive
+                    yield StreamHeartbeat().to_sse()
         except Exception as e:
             logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
+        finally:
+            # Unsubscribe when client disconnects or stream ends
+            await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
 
         # AI SDK protocol termination
         yield "data: [DONE]\n\n"

diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py
index 09eb8e6093..9b5085038a 100644
--- a/autogpt_platform/backend/backend/api/features/chat/service.py
+++ b/autogpt_platform/backend/backend/api/features/chat/service.py
@@ -1877,6 +1877,9 @@ async def _execute_long_running_tool_with_streaming(
     This function runs independently of the SSE connection, publishes progress
     to the stream registry, and survives if the user closes their browser tab.
     Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
+
+    If the external service returns a 202 Accepted (async), this function exits
+    early and lets the RabbitMQ completion consumer handle the rest.
     """
     try:
         # Load fresh session (not stale reference)
@@ -1886,15 +1889,39 @@ async def _execute_long_running_tool_with_streaming(
             await stream_registry.mark_task_completed(task_id, status="failed")
             return
 
+        # Pass operation_id and task_id to the tool for async processing
+        enriched_parameters = {
+            **parameters,
+            "_operation_id": operation_id,
+            "_task_id": task_id,
+        }
+
         # Execute the actual tool
         result = await execute_tool(
            tool_name=tool_name,
-            parameters=parameters,
+            parameters=enriched_parameters,
            tool_call_id=tool_call_id,
            user_id=user_id,
            session=session,
         )
 
+        # Check if the tool result indicates async processing
+        # (e.g., Agent Generator returned 202 Accepted)
+        try:
+            result_data = orjson.loads(result.output) if result.output else {}
+            if result_data.get("status") == "accepted":
+                logger.info(
+                    f"Tool {tool_name} delegated to async processing "
+                    f"(operation_id={operation_id}, task_id={task_id}). "
+                    f"RabbitMQ completion consumer will handle the rest."
+                )
+                # Don't publish result, don't continue with LLM
+                # The RabbitMQ consumer will handle everything when the external
+                # service completes and publishes to the queue
+                return
+        except (orjson.JSONDecodeError, TypeError):
+            pass  # Not JSON or not async - continue normally
+
         # Publish tool result to stream registry
         await stream_registry.publish_chunk(task_id, result)

diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py
index 72d4488c0e..b39861bb7f 100644
--- a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py
+++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py
@@ -6,6 +6,7 @@ messages. It supports:
 - Publishing stream messages to both Redis Streams and in-memory queues
 - Subscribing to tasks with replay of missed messages
 - Looking up tasks by operation_id for webhook callbacks
+- Cross-pod real-time delivery via Redis pub/sub
 """
 
 import asyncio
@@ -24,6 +25,9 @@ from .response_model import StreamBaseResponse, StreamFinish
 logger = logging.getLogger(__name__)
 config = ChatConfig()
 
+# Track active pub/sub listeners for cross-pod delivery
+_pubsub_listeners: dict[str, asyncio.Task] = {}
+
 
 @dataclass
 class ActiveTask:
@@ -39,6 +43,10 @@ class ActiveTask:
     created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
     queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue)
     asyncio_task: asyncio.Task | None = None
+    # Lock for atomic status checks and subscriber management
+    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+    # Set of subscriber queues for fan-out
+    subscribers: set[asyncio.Queue[StreamBaseResponse]] = field(default_factory=set)
 
 
 # Module-level registry for active tasks
@@ -48,6 +56,7 @@ _active_tasks: dict[str, ActiveTask] = {}
 TASK_META_PREFIX = "chat:task:meta:"  # Hash for task metadata
 TASK_STREAM_PREFIX = "chat:stream:"  # Redis Stream for messages
 TASK_OP_PREFIX = "chat:task:op:"  # Operation ID -> task_id mapping
+TASK_PUBSUB_PREFIX = "chat:task:pubsub:"  # Pub/sub channel for cross-pod delivery
 
 
 def _get_task_meta_key(task_id: str) -> str:
@@ -65,6 +74,11 @@ def _get_operation_mapping_key(operation_id: str) -> str:
     return f"{TASK_OP_PREFIX}{operation_id}"
 
 
+def _get_task_pubsub_channel(task_id: str) -> str:
+    """Get Redis pub/sub channel for task cross-pod delivery."""
+    return f"{TASK_PUBSUB_PREFIX}{task_id}"
+
+
 async def create_task(
     task_id: str,
     session_id: str,
@@ -132,58 +146,74 @@ async def create_task(
 async def publish_chunk(
     task_id: str,
     chunk: StreamBaseResponse,
-) -> int:
+) -> str:
     """Publish a chunk to the task's stream.
 
-    Writes to both Redis Stream (for replay) and in-memory queue (for live subscribers).
+    Delivers to in-memory subscribers first (for real-time), then persists to
+    Redis Stream (for replay). This order ensures live subscribers get messages
+    even if Redis temporarily fails.
 
     Args:
         task_id: Task ID to publish to
         chunk: The stream response chunk to publish
 
     Returns:
-        The message index in the Redis Stream
+        The Redis Stream message ID (format: "timestamp-sequence"), or "0-0" if
+        Redis persistence failed
     """
-    redis = await get_redis_async()
-    stream_key = _get_task_stream_key(task_id)
-
-    # Serialize chunk to JSON
-    chunk_json = chunk.model_dump_json()
-
-    # Add to Redis Stream with auto-generated ID
-    # The ID format is "timestamp-sequence" which gives us ordering
-    message_id = await redis.xadd(
-        stream_key,
-        {"data": chunk_json},
-        maxlen=config.stream_max_length,
-    )
-
-    # Publish to in-memory queue if task exists
+    # Deliver to in-memory subscribers FIRST for real-time updates
     task = _active_tasks.get(task_id)
     if task:
-        try:
-            task.queue.put_nowait(chunk)
-        except asyncio.QueueFull:
-            logger.warning(f"Queue full for task {task_id}, dropping chunk")
+        async with task.lock:
+            for subscriber_queue in task.subscribers:
+                try:
+                    subscriber_queue.put_nowait(chunk)
+                except asyncio.QueueFull:
+                    logger.warning(
+                        f"Subscriber queue full for task {task_id}, dropping chunk"
+                    )
 
-    logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
+    # Then persist to Redis Stream for replay (with error handling)
+    message_id = "0-0"
+    chunk_json = chunk.model_dump_json()
+    try:
+        redis = await get_redis_async()
+        stream_key = _get_task_stream_key(task_id)
 
-    # Parse the message_id to extract the index
-    # Redis Stream IDs are "timestamp-sequence", we return the raw ID
-    return int(message_id.split("-")[1]) if "-" in message_id else 0
+        # Add to Redis Stream with auto-generated ID
+        # The ID format is "timestamp-sequence" which gives us ordering
+        raw_id = await redis.xadd(
+            stream_key,
+            {"data": chunk_json},
+            maxlen=config.stream_max_length,
+        )
+        message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
+
+        # Publish to pub/sub for cross-pod real-time delivery
+        pubsub_channel = _get_task_pubsub_channel(task_id)
+        await redis.publish(pubsub_channel, chunk_json)
+
+        logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
+    except Exception as e:
+        logger.error(
+            f"Failed to persist chunk to Redis for task {task_id}: {e}",
+            exc_info=True,
+        )
+
+    return message_id
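
[Editor's note] The switch above from integer indexes to Redis Stream IDs relies on XADD's auto-generated "timestamp-sequence" IDs being totally ordered, and on XREAD returning only entries strictly after the ID you pass. A minimal sketch of that resume behavior, assuming a local Redis and the redis-py async client (the stream name is hypothetical):

```python
import asyncio
import redis.asyncio as redis

async def main() -> None:
    r = redis.Redis(decode_responses=True)
    stream = "chat:stream:demo"  # same TASK_STREAM_PREFIX naming scheme as above

    # XADD returns auto-generated IDs of the form "timestamp-sequence"
    first_id = await r.xadd(stream, {"data": '{"n": 1}'})
    await r.xadd(stream, {"data": '{"n": 2}'})

    # XREAD with an explicit ID yields only entries *after* that ID,
    # which is exactly the replay-after-last_message_id contract
    entries = await r.xread({stream: first_id})
    for _name, messages in entries:
        for msg_id, fields in messages:
            print(msg_id, fields["data"])  # only {"n": 2} is replayed

asyncio.run(main())
```
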
 
 
 async def subscribe_to_task(
     task_id: str,
     user_id: str | None,
-    last_idx: int = 0,
+    last_message_id: str = "0-0",
 ) -> asyncio.Queue[StreamBaseResponse] | None:
     """Subscribe to a task's stream with replay of missed messages.
 
     Args:
         task_id: Task ID to subscribe to
         user_id: User ID for ownership validation
-        last_idx: Last message index received (0 for full replay)
+        last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
 
     Returns:
         An asyncio Queue that will receive stream chunks, or None if task not found
@@ -208,15 +238,21 @@ async def subscribe_to_task(
         redis = await get_redis_async()
         stream_key = _get_task_stream_key(task_id)
 
-        # Read all messages from stream
-        # Use "0-0" to get all messages or construct ID from last_idx
-        start_id = "0-0" if last_idx == 0 else f"0-{last_idx}"
-        messages = await redis.xread({stream_key: start_id}, block=0, count=1000)
+        # Track the last message ID we've seen for gap detection
+        replay_last_id = last_message_id
+
+        # Read all messages from stream starting after last_message_id
+        # xread returns messages with ID > last_message_id
+        messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
 
         if messages:
             # messages format: [[stream_name, [(id, {data: json}), ...]]]
             for _stream_name, stream_messages in messages:
-                for _msg_id, msg_data in stream_messages:
+                for msg_id, msg_data in stream_messages:
+                    # Track the last message ID we've processed
+                    replay_last_id = (
+                        msg_id if isinstance(msg_id, str) else msg_id.decode()
+                    )
                     if b"data" in msg_data:
                         try:
                             chunk_data = orjson.loads(msg_data[b"data"])
@@ -227,23 +263,44 @@ async def subscribe_to_task(
                         except Exception as e:
                             logger.warning(f"Failed to replay message: {e}")
 
-        # If task is still running, set up live subscription
-        if task.status == "running":
-            # Forward messages from task queue to subscriber queue
-            async def _forward_messages():
-                try:
-                    while True:
-                        chunk = await task.queue.get()
-                        await subscriber_queue.put(chunk)
-                        if isinstance(chunk, StreamFinish):
-                            break
-                except asyncio.CancelledError:
-                    pass
+        # Atomically check status and register subscriber under lock
+        # This prevents race condition where task completes between check and subscribe
+        should_start_pubsub = False
+        async with task.lock:
+            if task.status == "running":
+                # Register this subscriber for live updates
+                task.subscribers.add(subscriber_queue)
+                # Start pub/sub listener if this is the first subscriber
+                should_start_pubsub = len(task.subscribers) == 1
+                logger.debug(
+                    f"Registered subscriber for task {task_id}, "
+                    f"total subscribers: {len(task.subscribers)}"
+                )
+            else:
+                # Task is done, add finish marker
+                await subscriber_queue.put(StreamFinish())
 
-            asyncio.create_task(_forward_messages())
-        else:
-            # Task is done, add finish marker
-            await subscriber_queue.put(StreamFinish())
+        # After registering, do a second read to catch any messages published
+        # between the first read and registration (closes the race window)
+        if task.status == "running":
+            gap_messages = await redis.xread(
+                {stream_key: replay_last_id}, block=0, count=1000
+            )
+            if gap_messages:
+                for _stream_name, stream_messages in gap_messages:
+                    for _msg_id, msg_data in stream_messages:
+                        if b"data" in msg_data:
+                            try:
+                                chunk_data = orjson.loads(msg_data[b"data"])
+                                chunk = _reconstruct_chunk(chunk_data)
+                                if chunk:
+                                    await subscriber_queue.put(chunk)
+                            except Exception as e:
+                                logger.warning(f"Failed to replay gap message: {e}")
+
+        # Start pub/sub listener outside the lock to avoid deadlocks
+        if should_start_pubsub:
+            await start_pubsub_listener(task_id)
 
         return subscriber_queue
 
@@ -269,8 +326,8 @@ async def subscribe_to_task(
     subscriber_queue = asyncio.Queue()
 
     stream_key = _get_task_stream_key(task_id)
-    start_id = "0-0" if last_idx == 0 else f"0-{last_idx}"
-    messages = await redis.xread({stream_key: start_id}, block=0, count=1000)
+    # Read all messages starting after last_message_id
+    messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
 
     if messages:
         for _stream_name, stream_messages in messages:
@@ -303,8 +360,25 @@ async def mark_task_completed(
     task = _active_tasks.get(task_id)
 
     if task:
-        task.status = status
-        # Publish finish event to all subscribers
+        # Acquire lock to prevent new subscribers during completion
+        async with task.lock:
+            task.status = status
+            # Send finish event directly to all current subscribers
+            finish_event = StreamFinish()
+            for subscriber_queue in task.subscribers:
+                try:
+                    subscriber_queue.put_nowait(finish_event)
+                except asyncio.QueueFull:
+                    logger.warning(
+                        f"Subscriber queue full for task {task_id} during completion"
+                    )
+            # Clear subscribers since task is done
+            task.subscribers.clear()
+
+        # Stop pub/sub listener since task is done
+        await stop_pubsub_listener(task_id)
+
+        # Also publish to Redis Stream for replay (and pub/sub for cross-pod)
         await publish_chunk(task_id, StreamFinish())
 
     # Remove from active tasks after a short delay to allow subscribers to finish
@@ -468,3 +542,107 @@ async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> Non
     task = _active_tasks.get(task_id)
     if task:
         task.asyncio_task = asyncio_task
+
+
+async def unsubscribe_from_task(
+    task_id: str,
+    subscriber_queue: asyncio.Queue[StreamBaseResponse],
+) -> None:
+    """Unsubscribe a queue from a task's stream.
+
+    Should be called when a client disconnects to clean up resources.
+    Also stops the pub/sub listener if there are no more local subscribers.
+
+    Args:
+        task_id: Task ID to unsubscribe from
+        subscriber_queue: The queue to remove from subscribers
+    """
+    task = _active_tasks.get(task_id)
+    if task:
+        async with task.lock:
+            task.subscribers.discard(subscriber_queue)
+            remaining = len(task.subscribers)
+            logger.debug(
+                f"Unsubscribed from task {task_id}, "
+                f"remaining subscribers: {remaining}"
+            )
+            # Stop pub/sub listener if no more local subscribers
+            if remaining == 0:
+                await stop_pubsub_listener(task_id)
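
[Editor's note] Because the in-memory queues are per-pod, the pub/sub channel defined below is what lets the pod holding the SSE connection see chunks produced by the pod running the tool. A two-task sketch of that fan-out under stated assumptions (local Redis, redis-py async client, hypothetical channel name matching the TASK_PUBSUB_PREFIX scheme):

```python
import asyncio
import redis.asyncio as redis

CHANNEL = "chat:task:pubsub:demo-task"  # hypothetical task channel

async def pod_a_publish() -> None:
    # Pod A: after persisting to the stream, fan out to live pods
    r = redis.Redis()
    await r.publish(CHANNEL, '{"type": "text", "content": "hello"}')

async def pod_b_listen() -> None:
    # Pod B: forward anything on the channel to its local subscriber queues
    r = redis.Redis()
    pubsub = r.pubsub()
    await pubsub.subscribe(CHANNEL)
    async for message in pubsub.listen():
        if message["type"] == "message":
            print("pod B got:", message["data"])
            break
    await pubsub.unsubscribe(CHANNEL)
    await pubsub.close()

async def main() -> None:
    listener = asyncio.create_task(pod_b_listen())
    await asyncio.sleep(0.1)  # let the subscription register first
    await pod_a_publish()
    await listener

asyncio.run(main())
```
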
+
+
+async def start_pubsub_listener(task_id: str) -> None:
+    """Start listening to Redis pub/sub for cross-pod delivery.
+
+    This enables real-time updates when another pod publishes chunks for a task
+    that has local subscribers on this pod.
+
+    Args:
+        task_id: Task ID to listen for
+    """
+    if task_id in _pubsub_listeners:
+        return  # Already listening
+
+    task = _active_tasks.get(task_id)
+    if not task:
+        return
+
+    async def _listener():
+        try:
+            redis = await get_redis_async()
+            pubsub = redis.pubsub()
+            channel = _get_task_pubsub_channel(task_id)
+            await pubsub.subscribe(channel)
+            logger.debug(f"Started pub/sub listener for task {task_id}")
+
+            async for message in pubsub.listen():
+                if message["type"] != "message":
+                    continue
+
+                try:
+                    chunk_data = orjson.loads(message["data"])
+                    chunk = _reconstruct_chunk(chunk_data)
+                    if chunk:
+                        # Deliver to local subscribers
+                        local_task = _active_tasks.get(task_id)
+                        if local_task:
+                            async with local_task.lock:
+                                for queue in local_task.subscribers:
+                                    try:
+                                        queue.put_nowait(chunk)
+                                    except asyncio.QueueFull:
+                                        pass
+                        # Stop listening if this was a finish event
+                        if isinstance(chunk, StreamFinish):
+                            break
+                except Exception as e:
+                    logger.warning(f"Error processing pub/sub message: {e}")
+
+            await pubsub.unsubscribe(channel)
+            await pubsub.close()
+        except asyncio.CancelledError:
+            pass
+        except Exception as e:
+            logger.error(f"Pub/sub listener error for task {task_id}: {e}")
+        finally:
+            _pubsub_listeners.pop(task_id, None)
+            logger.debug(f"Stopped pub/sub listener for task {task_id}")
+
+    listener_task = asyncio.create_task(_listener())
+    _pubsub_listeners[task_id] = listener_task
+
+
+async def stop_pubsub_listener(task_id: str) -> None:
+    """Stop the pub/sub listener for a task.
+
+    Args:
+        task_id: Task ID to stop listening for
+    """
+    listener = _pubsub_listeners.pop(task_id, None)
+    if listener and not listener.done():
+        listener.cancel()
+        try:
+            await listener
+        except asyncio.CancelledError:
+            pass
+        logger.debug(f"Cancelled pub/sub listener for task {task_id}")

diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py
index fc15587110..653426cbac 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py
@@ -57,21 +57,32 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
     return await decompose_goal_external(description, context)
 
 
-async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
+async def generate_agent(
+    instructions: dict[str, Any],
+    operation_id: str | None = None,
+    task_id: str | None = None,
+) -> dict[str, Any] | None:
     """Generate agent JSON from instructions.
 
     Args:
         instructions: Structured instructions from decompose_goal
+        operation_id: Operation ID for async processing (enables RabbitMQ callback)
+        task_id: Task ID for async processing (enables RabbitMQ callback)
 
     Returns:
-        Agent JSON dict or None on error
+        Agent JSON dict, {"status": "accepted"} for async, or None on error
 
     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
""" _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent") - result = await generate_agent_external(instructions) + result = await generate_agent_external(instructions, operation_id, task_id) + + # Don't modify async response + if result and result.get("status") == "accepted": + return result + if result: # Ensure required fields if "id" not in result: @@ -253,7 +264,10 @@ async def get_agent_as_json( async def generate_agent_patch( - update_request: str, current_agent: dict[str, Any] + update_request: str, + current_agent: dict[str, Any], + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Update an existing agent using natural language. @@ -265,13 +279,17 @@ async def generate_agent_patch( Args: update_request: Natural language description of changes current_agent: Current agent JSON + operation_id: Operation ID for async processing (enables RabbitMQ callback) + task_id: Task ID for async processing (enables RabbitMQ callback) Returns: - Updated agent JSON, clarifying questions dict, or None on error + Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. """ _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent_patch") - return await generate_agent_patch_external(update_request, current_agent) + return await generate_agent_patch_external( + update_request, current_agent, operation_id, task_id + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py index a4d2f1af15..edad1e8e4e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py @@ -124,22 +124,39 @@ async def decompose_goal_external( async def generate_agent_external( - instructions: dict[str, Any] + instructions: dict[str, Any], + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate an agent from instructions. 

diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py
index a4d2f1af15..edad1e8e4e 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py
@@ -124,22 +124,39 @@ async def decompose_goal_external(
 
 
 async def generate_agent_external(
-    instructions: dict[str, Any]
+    instructions: dict[str, Any],
+    operation_id: str | None = None,
+    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate an agent from instructions.
 
     Args:
         instructions: Structured instructions from decompose_goal
+        operation_id: Operation ID for async processing (enables RabbitMQ callback)
+        task_id: Task ID for async processing (enables RabbitMQ callback)
 
     Returns:
-        Agent JSON dict or None on error
+        Agent JSON dict, {"status": "accepted"} for async, or None on error
     """
     client = _get_client()
 
+    # Build request payload
+    payload: dict[str, Any] = {"instructions": instructions}
+    if operation_id and task_id:
+        payload["operation_id"] = operation_id
+        payload["task_id"] = task_id
+
     try:
-        response = await client.post(
-            "/api/generate-agent", json={"instructions": instructions}
-        )
+        response = await client.post("/api/generate-agent", json=payload)
+
+        # Handle 202 Accepted for async processing
+        if response.status_code == 202:
+            logger.info(
+                f"Agent Generator accepted async request "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return {"status": "accepted", "operation_id": operation_id, "task_id": task_id}
+
         response.raise_for_status()
         data = response.json()
@@ -161,27 +178,44 @@ async def generate_agent_external(
 
 
 async def generate_agent_patch_external(
-    update_request: str, current_agent: dict[str, Any]
+    update_request: str,
+    current_agent: dict[str, Any],
+    operation_id: str | None = None,
+    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate a patch for an existing agent.
 
     Args:
         update_request: Natural language description of changes
         current_agent: Current agent JSON
+        operation_id: Operation ID for async processing (enables RabbitMQ callback)
+        task_id: Task ID for async processing (enables RabbitMQ callback)
 
     Returns:
-        Updated agent JSON, clarifying questions dict, or None on error
+        Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or None on error
     """
     client = _get_client()
 
+    # Build request payload
+    payload: dict[str, Any] = {
+        "update_request": update_request,
+        "current_agent_json": current_agent,
+    }
+    if operation_id and task_id:
+        payload["operation_id"] = operation_id
+        payload["task_id"] = task_id
+
     try:
-        response = await client.post(
-            "/api/update-agent",
-            json={
-                "update_request": update_request,
-                "current_agent_json": current_agent,
-            },
-        )
+        response = await client.post("/api/update-agent", json=payload)
+
+        # Handle 202 Accepted for async processing
+        if response.status_code == 202:
+            logger.info(
+                f"Agent Generator accepted async update request "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return {"status": "accepted", "operation_id": operation_id, "task_id": task_id}
+
         response.raise_for_status()
         data = response.json()

diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py
index 6b3784e323..c19ff83ecd 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py
@@ -15,6 +15,7 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
+    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -95,6 +96,10 @@ class CreateAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None
 
+        # Extract async processing params (passed by long-running tool handler)
+        operation_id = kwargs.get("_operation_id")
+        task_id = kwargs.get("_task_id")
+
         if not description:
             return ErrorResponse(
message="Please provide a description of what the agent should do.", @@ -173,7 +178,11 @@ class CreateAgentTool(BaseTool): # Step 2: Generate agent JSON (external service handles fixing and validation) try: - agent_json = await generate_agent(decomposition_result) + agent_json = await generate_agent( + decomposition_result, + operation_id=operation_id, + task_id=task_id, + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -194,6 +203,19 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) + # Check if Agent Generator accepted for async processing + if agent_json.get("status") == "accepted": + logger.info( + f"Agent generation delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent generation started. You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + ) + agent_name = agent_json.get("name", "Generated Agent") agent_description = agent_json.get("description", "") node_count = len(agent_json.get("nodes", [])) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py index 7c4da8ad43..b006e04b41 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py @@ -15,6 +15,7 @@ from .base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -102,6 +103,10 @@ class EditAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not agent_id: return ErrorResponse( message="Please provide the agent ID to edit.", @@ -133,7 +138,12 @@ class EditAgentTool(BaseTool): # Step 2: Generate updated agent (external service handles fixing and validation) try: - result = await generate_agent_patch(update_request, current_agent) + result = await generate_agent_patch( + update_request, + current_agent, + operation_id=operation_id, + task_id=task_id, + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -152,6 +162,19 @@ class EditAgentTool(BaseTool): session_id=session_id, ) + # Check if Agent Generator accepted for async processing + if result.get("status") == "accepted": + logger.info( + f"Agent edit delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent edit started. 
You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + ) + # Check if LLM returned clarifying questions if result.get("type") == "clarifying_questions": questions = result.get("questions", []) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index e907042ba3..d14bf5a2c1 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -384,3 +384,20 @@ class OperationInProgressResponse(ToolResponseBase): type: ResponseType = ResponseType.OPERATION_IN_PROGRESS tool_call_id: str + + +class AsyncProcessingResponse(ToolResponseBase): + """Response when an operation has been delegated to async processing. + + This is returned by tools when the external service accepts the request + for async processing (HTTP 202 Accepted). The RabbitMQ completion consumer + will handle the result when the external service completes. + + The status field is specifically "accepted" to allow the long-running tool + handler to detect this response and skip LLM continuation. + """ + + type: ResponseType = ResponseType.OPERATION_STARTED + status: str = "accepted" # Must be "accepted" for detection + operation_id: str | None = None + task_id: str | None = None
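
[Editor's note] End to end, the sentinel contract is: AsyncProcessingResponse serializes with status="accepted", and the long-running handler in service.py only needs that one field to skip LLM continuation. A self-contained sketch of the detection half, mirroring the patch's logic (the IDs are made-up examples):

```python
import orjson

# Shape the long-running tool handler looks for (see service.py above)
payload = {"status": "accepted", "operation_id": "op-123", "task_id": "task-456"}

def is_async_accepted(tool_output: str | None) -> bool:
    """Mirror of the handler's check: tolerate empty or non-JSON output."""
    try:
        data = orjson.loads(tool_output) if tool_output else {}
    except (orjson.JSONDecodeError, TypeError):
        return False
    return data.get("status") == "accepted"

assert is_async_accepted(orjson.dumps(payload).decode())
assert not is_async_accepted("plain text result")
assert not is_async_accepted(None)
```
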