mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-04 20:05:11 -05:00
fixing edge cases
@@ -15,7 +15,7 @@ from . import service as chat_service
 from . import stream_registry
 from .config import ChatConfig
 from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish
+from .response_model import StreamFinish, StreamHeartbeat

 config = ChatConfig()

@@ -385,7 +385,10 @@ async def session_assign_user(
 async def stream_task(
     task_id: str,
     user_id: str | None = Depends(auth.get_user_id),
-    last_idx: int = Query(default=0, ge=0, description="Last message index received"),
+    last_message_id: str = Query(
+        default="0-0",
+        description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
+    ),
 ):
     """
     Reconnect to a long-running task's SSE stream.
@@ -397,10 +400,10 @@ async def stream_task(
     Args:
         task_id: The task ID from the operation_started response.
         user_id: Authenticated user ID for ownership validation.
-        last_idx: Last message index received (0 for full replay).
+        last_message_id: Last Redis Stream message ID received ("0-0" for full replay).

     Returns:
-        StreamingResponse: SSE-formatted response chunks starting from last_idx.
+        StreamingResponse: SSE-formatted response chunks starting after last_message_id.

     Raises:
         NotFoundError: If task_id is not found or user doesn't have access.
@@ -409,30 +412,42 @@ async def stream_task(
     subscriber_queue = await stream_registry.subscribe_to_task(
         task_id=task_id,
         user_id=user_id,
-        last_idx=last_idx,
+        last_message_id=last_message_id,
     )

     if subscriber_queue is None:
         raise NotFoundError(f"Task {task_id} not found or access denied.")

     async def event_generator() -> AsyncGenerator[str, None]:
+        import asyncio
+
         chunk_count = 0
+        heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
         try:
             while True:
-                # Wait for next chunk from the queue
-                chunk = await subscriber_queue.get()
-                chunk_count += 1
-                yield chunk.to_sse()
-
-                # Check for finish signal
-                if isinstance(chunk, StreamFinish):
-                    logger.info(
-                        f"Task stream completed for task {task_id}, "
-                        f"chunk_count={chunk_count}"
+                try:
+                    # Wait for next chunk with timeout for heartbeats
+                    chunk = await asyncio.wait_for(
+                        subscriber_queue.get(), timeout=heartbeat_interval
                     )
-                    break
+                    chunk_count += 1
+                    yield chunk.to_sse()
+
+                    # Check for finish signal
+                    if isinstance(chunk, StreamFinish):
+                        logger.info(
+                            f"Task stream completed for task {task_id}, "
+                            f"chunk_count={chunk_count}"
+                        )
+                        break
+                except asyncio.TimeoutError:
+                    # Send heartbeat to keep connection alive
+                    yield StreamHeartbeat().to_sse()
         except Exception as e:
             logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
         finally:
             # Unsubscribe when client disconnects or stream ends
             await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
+
+            # AI SDK protocol termination
+            yield "data: [DONE]\n\n"
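
For context on how a client would use the reconnect endpoint above, here is a minimal sketch. The path and the last_message_id query parameter come from the diff; the httpx-based consumer, the base URL, and the omission of authentication headers are assumptions.

import asyncio

import httpx


async def resume_task_stream(base_url: str, task_id: str, last_message_id: str = "0-0") -> None:
    """Reconnect to GET /chat/tasks/{task_id}/stream and print chunks until [DONE]."""
    url = f"{base_url}/chat/tasks/{task_id}/stream"
    params = {"last_message_id": last_message_id}  # "0-0" replays the full stream
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url, params=params) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # ignore SSE comments and blank keep-alive lines
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break  # AI SDK protocol termination emitted by the server
                print(payload)  # one serialized stream chunk


# asyncio.run(resume_task_stream("http://localhost:8000", "<task-id>"))  # hypothetical base URL
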
@@ -1877,6 +1877,9 @@ async def _execute_long_running_tool_with_streaming(
     This function runs independently of the SSE connection, publishes progress
     to the stream registry, and survives if the user closes their browser tab.
     Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
+
+    If the external service returns a 202 Accepted (async), this function exits
+    early and lets the RabbitMQ completion consumer handle the rest.
     """
     try:
         # Load fresh session (not stale reference)
@@ -1886,15 +1889,39 @@ async def _execute_long_running_tool_with_streaming(
             await stream_registry.mark_task_completed(task_id, status="failed")
             return

+        # Pass operation_id and task_id to the tool for async processing
+        enriched_parameters = {
+            **parameters,
+            "_operation_id": operation_id,
+            "_task_id": task_id,
+        }
+
         # Execute the actual tool
         result = await execute_tool(
             tool_name=tool_name,
-            parameters=parameters,
+            parameters=enriched_parameters,
             tool_call_id=tool_call_id,
             user_id=user_id,
             session=session,
         )

+        # Check if the tool result indicates async processing
+        # (e.g., Agent Generator returned 202 Accepted)
+        try:
+            result_data = orjson.loads(result.output) if result.output else {}
+            if result_data.get("status") == "accepted":
+                logger.info(
+                    f"Tool {tool_name} delegated to async processing "
+                    f"(operation_id={operation_id}, task_id={task_id}). "
+                    f"RabbitMQ completion consumer will handle the rest."
+                )
+                # Don't publish result, don't continue with LLM
+                # The RabbitMQ consumer will handle everything when the external
+                # service completes and publishes to the queue
+                return
+        except (orjson.JSONDecodeError, TypeError):
+            pass  # Not JSON or not async - continue normally
+
         # Publish tool result to stream registry
         await stream_registry.publish_chunk(task_id, result)

@@ -6,6 +6,7 @@ messages. It supports:
 - Publishing stream messages to both Redis Streams and in-memory queues
 - Subscribing to tasks with replay of missed messages
 - Looking up tasks by operation_id for webhook callbacks
+- Cross-pod real-time delivery via Redis pub/sub
 """

 import asyncio
@@ -24,6 +25,9 @@ from .response_model import StreamBaseResponse, StreamFinish
 logger = logging.getLogger(__name__)
 config = ChatConfig()

+# Track active pub/sub listeners for cross-pod delivery
+_pubsub_listeners: dict[str, asyncio.Task] = {}
+

 @dataclass
 class ActiveTask:
@@ -39,6 +43,10 @@ class ActiveTask:
     created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
     queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue)
     asyncio_task: asyncio.Task | None = None
+    # Lock for atomic status checks and subscriber management
+    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+    # Set of subscriber queues for fan-out
+    subscribers: set[asyncio.Queue[StreamBaseResponse]] = field(default_factory=set)


 # Module-level registry for active tasks
@@ -48,6 +56,7 @@ _active_tasks: dict[str, ActiveTask] = {}
 TASK_META_PREFIX = "chat:task:meta:"  # Hash for task metadata
 TASK_STREAM_PREFIX = "chat:stream:"  # Redis Stream for messages
 TASK_OP_PREFIX = "chat:task:op:"  # Operation ID -> task_id mapping
+TASK_PUBSUB_PREFIX = "chat:task:pubsub:"  # Pub/sub channel for cross-pod delivery


 def _get_task_meta_key(task_id: str) -> str:
@@ -65,6 +74,11 @@ def _get_operation_mapping_key(operation_id: str) -> str:
     return f"{TASK_OP_PREFIX}{operation_id}"


+def _get_task_pubsub_channel(task_id: str) -> str:
+    """Get Redis pub/sub channel for task cross-pod delivery."""
+    return f"{TASK_PUBSUB_PREFIX}{task_id}"
+
+
 async def create_task(
     task_id: str,
     session_id: str,
@@ -132,58 +146,74 @@ async def create_task(
 async def publish_chunk(
     task_id: str,
     chunk: StreamBaseResponse,
-) -> int:
+) -> str:
     """Publish a chunk to the task's stream.

-    Writes to both Redis Stream (for replay) and in-memory queue (for live subscribers).
+    Delivers to in-memory subscribers first (for real-time), then persists to
+    Redis Stream (for replay). This order ensures live subscribers get messages
+    even if Redis temporarily fails.

     Args:
         task_id: Task ID to publish to
         chunk: The stream response chunk to publish

     Returns:
-        The message index in the Redis Stream
+        The Redis Stream message ID (format: "timestamp-sequence"), or "0-0" if
+        Redis persistence failed
     """
-    redis = await get_redis_async()
-    stream_key = _get_task_stream_key(task_id)
-
-    # Serialize chunk to JSON
-    chunk_json = chunk.model_dump_json()
-
-    # Add to Redis Stream with auto-generated ID
-    # The ID format is "timestamp-sequence" which gives us ordering
-    message_id = await redis.xadd(
-        stream_key,
-        {"data": chunk_json},
-        maxlen=config.stream_max_length,
-    )
-
-    # Publish to in-memory queue if task exists
+    # Deliver to in-memory subscribers FIRST for real-time updates
     task = _active_tasks.get(task_id)
     if task:
-        try:
-            task.queue.put_nowait(chunk)
-        except asyncio.QueueFull:
-            logger.warning(f"Queue full for task {task_id}, dropping chunk")
+        async with task.lock:
+            for subscriber_queue in task.subscribers:
+                try:
+                    subscriber_queue.put_nowait(chunk)
+                except asyncio.QueueFull:
+                    logger.warning(
+                        f"Subscriber queue full for task {task_id}, dropping chunk"
+                    )

-    logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
+    # Then persist to Redis Stream for replay (with error handling)
+    message_id = "0-0"
+    chunk_json = chunk.model_dump_json()
+    try:
+        redis = await get_redis_async()
+        stream_key = _get_task_stream_key(task_id)

-    # Parse the message_id to extract the index
-    # Redis Stream IDs are "timestamp-sequence", we return the raw ID
-    return int(message_id.split("-")[1]) if "-" in message_id else 0
+        # Add to Redis Stream with auto-generated ID
+        # The ID format is "timestamp-sequence" which gives us ordering
+        raw_id = await redis.xadd(
+            stream_key,
+            {"data": chunk_json},
+            maxlen=config.stream_max_length,
+        )
+        message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
+
+        # Publish to pub/sub for cross-pod real-time delivery
+        pubsub_channel = _get_task_pubsub_channel(task_id)
+        await redis.publish(pubsub_channel, chunk_json)
+
+        logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
+    except Exception as e:
+        logger.error(
+            f"Failed to persist chunk to Redis for task {task_id}: {e}",
+            exc_info=True,
+        )
+
+    return message_id


 async def subscribe_to_task(
     task_id: str,
     user_id: str | None,
-    last_idx: int = 0,
+    last_message_id: str = "0-0",
 ) -> asyncio.Queue[StreamBaseResponse] | None:
     """Subscribe to a task's stream with replay of missed messages.

     Args:
         task_id: Task ID to subscribe to
         user_id: User ID for ownership validation
-        last_idx: Last message index received (0 for full replay)
+        last_message_id: Last Redis Stream message ID received ("0-0" for full replay)

     Returns:
         An asyncio Queue that will receive stream chunks, or None if task not found
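
Aside: the "timestamp-sequence" IDs that publish_chunk now returns are ordinary Redis Stream entry IDs. The following standalone sketch of the XADD/XREAD cursor semantics the registry relies on is not part of the diff; it assumes redis-py's asyncio client and a local Redis, and the key name is chosen arbitrarily.

import asyncio

import redis.asyncio as redis


async def main() -> None:
    r = redis.from_url("redis://localhost:6379", decode_responses=True)
    key = "chat:stream:demo-task"

    # XADD auto-generates IDs of the form "<ms-timestamp>-<sequence>".
    first_id = await r.xadd(key, {"data": '{"type": "text"}'}, maxlen=1000)
    second_id = await r.xadd(key, {"data": '{"type": "finish"}'}, maxlen=1000)
    print(first_id, second_id)

    # XREAD returns only entries strictly newer than the ID passed in,
    # which is why "0-0" means "replay everything".
    full_replay = await r.xread({key: "0-0"}, count=100)
    resumed = await r.xread({key: first_id}, count=100)
    print(len(full_replay[0][1]), len(resumed[0][1]))  # 2 entries, then 1

    await r.close()


asyncio.run(main())
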
@@ -208,15 +238,21 @@ async def subscribe_to_task(
     redis = await get_redis_async()
     stream_key = _get_task_stream_key(task_id)

-    # Read all messages from stream
-    # Use "0-0" to get all messages or construct ID from last_idx
-    start_id = "0-0" if last_idx == 0 else f"0-{last_idx}"
-    messages = await redis.xread({stream_key: start_id}, block=0, count=1000)
+    # Track the last message ID we've seen for gap detection
+    replay_last_id = last_message_id
+
+    # Read all messages from stream starting after last_message_id
+    # xread returns messages with ID > last_message_id
+    messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)

     if messages:
         # messages format: [[stream_name, [(id, {data: json}), ...]]]
         for _stream_name, stream_messages in messages:
-            for _msg_id, msg_data in stream_messages:
+            for msg_id, msg_data in stream_messages:
+                # Track the last message ID we've processed
+                replay_last_id = (
+                    msg_id if isinstance(msg_id, str) else msg_id.decode()
+                )
                 if b"data" in msg_data:
                     try:
                         chunk_data = orjson.loads(msg_data[b"data"])
@@ -227,23 +263,44 @@ async def subscribe_to_task(
                     except Exception as e:
                         logger.warning(f"Failed to replay message: {e}")

-    # If task is still running, set up live subscription
-    if task.status == "running":
-        # Forward messages from task queue to subscriber queue
-        async def _forward_messages():
-            try:
-                while True:
-                    chunk = await task.queue.get()
-                    await subscriber_queue.put(chunk)
-                    if isinstance(chunk, StreamFinish):
-                        break
-            except asyncio.CancelledError:
-                pass
+    # Atomically check status and register subscriber under lock
+    # This prevents race condition where task completes between check and subscribe
+    should_start_pubsub = False
+    async with task.lock:
+        if task.status == "running":
+            # Register this subscriber for live updates
+            task.subscribers.add(subscriber_queue)
+            # Start pub/sub listener if this is the first subscriber
+            should_start_pubsub = len(task.subscribers) == 1
+            logger.debug(
+                f"Registered subscriber for task {task_id}, "
+                f"total subscribers: {len(task.subscribers)}"
+            )
+        else:
+            # Task is done, add finish marker
+            await subscriber_queue.put(StreamFinish())

-        asyncio.create_task(_forward_messages())
-    else:
-        # Task is done, add finish marker
-        await subscriber_queue.put(StreamFinish())
+    # After registering, do a second read to catch any messages published
+    # between the first read and registration (closes the race window)
+    if task.status == "running":
+        gap_messages = await redis.xread(
+            {stream_key: replay_last_id}, block=0, count=1000
+        )
+        if gap_messages:
+            for _stream_name, stream_messages in gap_messages:
+                for _msg_id, msg_data in stream_messages:
+                    if b"data" in msg_data:
+                        try:
+                            chunk_data = orjson.loads(msg_data[b"data"])
+                            chunk = _reconstruct_chunk(chunk_data)
+                            if chunk:
+                                await subscriber_queue.put(chunk)
+                        except Exception as e:
+                            logger.warning(f"Failed to replay gap message: {e}")
+
+    # Start pub/sub listener outside the lock to avoid deadlocks
+    if should_start_pubsub:
+        await start_pubsub_listener(task_id)

     return subscriber_queue
@@ -269,8 +326,8 @@ async def subscribe_to_task(
     subscriber_queue = asyncio.Queue()
     stream_key = _get_task_stream_key(task_id)

-    start_id = "0-0" if last_idx == 0 else f"0-{last_idx}"
-    messages = await redis.xread({stream_key: start_id}, block=0, count=1000)
+    # Read all messages starting after last_message_id
+    messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)

     if messages:
         for _stream_name, stream_messages in messages:
@@ -303,8 +360,25 @@ async def mark_task_completed(
     task = _active_tasks.get(task_id)

     if task:
-        task.status = status
-        # Publish finish event to all subscribers
+        # Acquire lock to prevent new subscribers during completion
+        async with task.lock:
+            task.status = status
+            # Send finish event directly to all current subscribers
+            finish_event = StreamFinish()
+            for subscriber_queue in task.subscribers:
+                try:
+                    subscriber_queue.put_nowait(finish_event)
+                except asyncio.QueueFull:
+                    logger.warning(
+                        f"Subscriber queue full for task {task_id} during completion"
+                    )
+            # Clear subscribers since task is done
+            task.subscribers.clear()
+
+        # Stop pub/sub listener since task is done
+        await stop_pubsub_listener(task_id)
+
+        # Also publish to Redis Stream for replay (and pub/sub for cross-pod)
         await publish_chunk(task_id, StreamFinish())

     # Remove from active tasks after a short delay to allow subscribers to finish
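
The completion path above and publish_chunk use the same fan-out idiom: deliver to every subscriber queue under the task lock with put_nowait, so one full (slow) queue is skipped rather than blocking the publisher. A self-contained toy version of that idiom follows; the names are illustrative, not from the codebase.

import asyncio


async def fan_out(lock: asyncio.Lock, subscribers: set[asyncio.Queue[str]], chunk: str) -> None:
    async with lock:
        for queue in subscribers:
            try:
                queue.put_nowait(chunk)  # non-blocking: never await a slow consumer
            except asyncio.QueueFull:
                pass  # drop the chunk for that subscriber only


async def main() -> None:
    lock = asyncio.Lock()
    fast: asyncio.Queue[str] = asyncio.Queue()
    slow: asyncio.Queue[str] = asyncio.Queue(maxsize=1)
    subscribers = {fast, slow}

    for i in range(3):
        await fan_out(lock, subscribers, f"chunk-{i}")

    print(fast.qsize(), slow.qsize())  # 3 1 -- the bounded queue dropped two chunks


asyncio.run(main())
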
@@ -468,3 +542,107 @@ async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> Non
     task = _active_tasks.get(task_id)
     if task:
         task.asyncio_task = asyncio_task
+
+
+async def unsubscribe_from_task(
+    task_id: str,
+    subscriber_queue: asyncio.Queue[StreamBaseResponse],
+) -> None:
+    """Unsubscribe a queue from a task's stream.
+
+    Should be called when a client disconnects to clean up resources.
+    Also stops the pub/sub listener if there are no more local subscribers.
+
+    Args:
+        task_id: Task ID to unsubscribe from
+        subscriber_queue: The queue to remove from subscribers
+    """
+    task = _active_tasks.get(task_id)
+    if task:
+        async with task.lock:
+            task.subscribers.discard(subscriber_queue)
+            remaining = len(task.subscribers)
+            logger.debug(
+                f"Unsubscribed from task {task_id}, "
+                f"remaining subscribers: {remaining}"
+            )
+            # Stop pub/sub listener if no more local subscribers
+            if remaining == 0:
+                await stop_pubsub_listener(task_id)
+
+
+async def start_pubsub_listener(task_id: str) -> None:
+    """Start listening to Redis pub/sub for cross-pod delivery.
+
+    This enables real-time updates when another pod publishes chunks for a task
+    that has local subscribers on this pod.
+
+    Args:
+        task_id: Task ID to listen for
+    """
+    if task_id in _pubsub_listeners:
+        return  # Already listening
+
+    task = _active_tasks.get(task_id)
+    if not task:
+        return
+
+    async def _listener():
+        try:
+            redis = await get_redis_async()
+            pubsub = redis.pubsub()
+            channel = _get_task_pubsub_channel(task_id)
+            await pubsub.subscribe(channel)
+            logger.debug(f"Started pub/sub listener for task {task_id}")
+
+            async for message in pubsub.listen():
+                if message["type"] != "message":
+                    continue
+
+                try:
+                    chunk_data = orjson.loads(message["data"])
+                    chunk = _reconstruct_chunk(chunk_data)
+                    if chunk:
+                        # Deliver to local subscribers
+                        local_task = _active_tasks.get(task_id)
+                        if local_task:
+                            async with local_task.lock:
+                                for queue in local_task.subscribers:
+                                    try:
+                                        queue.put_nowait(chunk)
+                                    except asyncio.QueueFull:
+                                        pass
+                        # Stop listening if this was a finish event
+                        if isinstance(chunk, StreamFinish):
+                            break
+                except Exception as e:
+                    logger.warning(f"Error processing pub/sub message: {e}")
+
+            await pubsub.unsubscribe(channel)
+            await pubsub.close()
+        except asyncio.CancelledError:
+            pass
+        except Exception as e:
+            logger.error(f"Pub/sub listener error for task {task_id}: {e}")
+        finally:
+            _pubsub_listeners.pop(task_id, None)
+            logger.debug(f"Stopped pub/sub listener for task {task_id}")
+
+    listener_task = asyncio.create_task(_listener())
+    _pubsub_listeners[task_id] = listener_task
+
+
+async def stop_pubsub_listener(task_id: str) -> None:
+    """Stop the pub/sub listener for a task.
+
+    Args:
+        task_id: Task ID to stop listening for
+    """
+    listener = _pubsub_listeners.pop(task_id, None)
+    if listener and not listener.done():
+        listener.cancel()
+        try:
+            await listener
+        except asyncio.CancelledError:
+            pass
+        logger.debug(f"Cancelled pub/sub listener for task {task_id}")
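
To see the cross-pod leg in isolation: one pod publishes the serialized chunk on the task's channel, and any pod with local subscribers runs a listener like the one above. The sketch below is not part of the diff; it uses redis-py's asyncio pub/sub, the channel name follows the chat:task:pubsub: prefix from the diff, and the local Redis URL and payload are made up.

import asyncio

import redis.asyncio as redis

CHANNEL = "chat:task:pubsub:demo-task"


async def listen_once(r: redis.Redis) -> None:
    pubsub = r.pubsub()
    await pubsub.subscribe(CHANNEL)
    async for message in pubsub.listen():
        if message["type"] != "message":
            continue  # skip the subscribe confirmation
        print("received:", message["data"])
        break
    await pubsub.unsubscribe(CHANNEL)
    await pubsub.close()


async def main() -> None:
    r = redis.from_url("redis://localhost:6379", decode_responses=True)
    listener = asyncio.create_task(listen_once(r))
    await asyncio.sleep(0.1)  # give the listener time to subscribe
    await r.publish(CHANNEL, '{"type": "finish"}')
    await listener
    await r.close()


asyncio.run(main())
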
@@ -57,21 +57,32 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
     return await decompose_goal_external(description, context)


-async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
+async def generate_agent(
+    instructions: dict[str, Any],
+    operation_id: str | None = None,
+    task_id: str | None = None,
+) -> dict[str, Any] | None:
     """Generate agent JSON from instructions.

     Args:
         instructions: Structured instructions from decompose_goal
+        operation_id: Operation ID for async processing (enables RabbitMQ callback)
+        task_id: Task ID for async processing (enables RabbitMQ callback)

     Returns:
-        Agent JSON dict or None on error
+        Agent JSON dict, {"status": "accepted"} for async, or None on error

     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
     """
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent")
-    result = await generate_agent_external(instructions)
+    result = await generate_agent_external(instructions, operation_id, task_id)
+
+    # Don't modify async response
+    if result and result.get("status") == "accepted":
+        return result

     if result:
         # Ensure required fields
         if "id" not in result:
@@ -253,7 +264,10 @@ async def get_agent_as_json(


 async def generate_agent_patch(
-    update_request: str, current_agent: dict[str, Any]
+    update_request: str,
+    current_agent: dict[str, Any],
+    operation_id: str | None = None,
+    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Update an existing agent using natural language.

@@ -265,13 +279,17 @@ async def generate_agent_patch(
     Args:
         update_request: Natural language description of changes
         current_agent: Current agent JSON
+        operation_id: Operation ID for async processing (enables RabbitMQ callback)
+        task_id: Task ID for async processing (enables RabbitMQ callback)

     Returns:
-        Updated agent JSON, clarifying questions dict, or None on error
+        Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or None on error

     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
     """
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent_patch")
-    return await generate_agent_patch_external(update_request, current_agent)
+    return await generate_agent_patch_external(
+        update_request, current_agent, operation_id, task_id
+    )

@@ -124,22 +124,39 @@ async def decompose_goal_external(


 async def generate_agent_external(
-    instructions: dict[str, Any]
+    instructions: dict[str, Any],
+    operation_id: str | None = None,
+    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate an agent from instructions.

     Args:
         instructions: Structured instructions from decompose_goal
+        operation_id: Operation ID for async processing (enables RabbitMQ callback)
+        task_id: Task ID for async processing (enables RabbitMQ callback)

     Returns:
-        Agent JSON dict or None on error
+        Agent JSON dict, or {"status": "accepted"} for async, or None on error
     """
     client = _get_client()

+    # Build request payload
+    payload: dict[str, Any] = {"instructions": instructions}
+    if operation_id and task_id:
+        payload["operation_id"] = operation_id
+        payload["task_id"] = task_id
+
     try:
-        response = await client.post(
-            "/api/generate-agent", json={"instructions": instructions}
-        )
+        response = await client.post("/api/generate-agent", json=payload)
+
+        # Handle 202 Accepted for async processing
+        if response.status_code == 202:
+            logger.info(
+                f"Agent Generator accepted async request "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return {"status": "accepted", "operation_id": operation_id, "task_id": task_id}
+
         response.raise_for_status()
         data = response.json()

@@ -161,27 +178,44 @@ async def generate_agent_external(


 async def generate_agent_patch_external(
-    update_request: str, current_agent: dict[str, Any]
+    update_request: str,
+    current_agent: dict[str, Any],
+    operation_id: str | None = None,
+    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate a patch for an existing agent.

     Args:
         update_request: Natural language description of changes
         current_agent: Current agent JSON
+        operation_id: Operation ID for async processing (enables RabbitMQ callback)
+        task_id: Task ID for async processing (enables RabbitMQ callback)

     Returns:
-        Updated agent JSON, clarifying questions dict, or None on error
+        Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or None on error
     """
     client = _get_client()

+    # Build request payload
+    payload: dict[str, Any] = {
+        "update_request": update_request,
+        "current_agent_json": current_agent,
+    }
+    if operation_id and task_id:
+        payload["operation_id"] = operation_id
+        payload["task_id"] = task_id
+
     try:
-        response = await client.post(
-            "/api/update-agent",
-            json={
-                "update_request": update_request,
-                "current_agent_json": current_agent,
-            },
-        )
+        response = await client.post("/api/update-agent", json=payload)
+
+        # Handle 202 Accepted for async processing
+        if response.status_code == 202:
+            logger.info(
+                f"Agent Generator accepted async update request "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return {"status": "accepted", "operation_id": operation_id, "task_id": task_id}
+
         response.raise_for_status()
         data = response.json()

@@ -15,6 +15,7 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
+    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -95,6 +96,10 @@ class CreateAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None

+        # Extract async processing params (passed by long-running tool handler)
+        operation_id = kwargs.get("_operation_id")
+        task_id = kwargs.get("_task_id")
+
         if not description:
             return ErrorResponse(
                 message="Please provide a description of what the agent should do.",
@@ -173,7 +178,11 @@ class CreateAgentTool(BaseTool):

         # Step 2: Generate agent JSON (external service handles fixing and validation)
         try:
-            agent_json = await generate_agent(decomposition_result)
+            agent_json = await generate_agent(
+                decomposition_result,
+                operation_id=operation_id,
+                task_id=task_id,
+            )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -194,6 +203,19 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

+        # Check if Agent Generator accepted for async processing
+        if agent_json.get("status") == "accepted":
+            logger.info(
+                f"Agent generation delegated to async processing "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return AsyncProcessingResponse(
+                message="Agent generation started. You'll be notified when it's complete.",
+                operation_id=operation_id,
+                task_id=task_id,
+                session_id=session_id,
+            )
+
         agent_name = agent_json.get("name", "Generated Agent")
         agent_description = agent_json.get("description", "")
         node_count = len(agent_json.get("nodes", []))
@@ -15,6 +15,7 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
+    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -102,6 +103,10 @@ class EditAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None

+        # Extract async processing params (passed by long-running tool handler)
+        operation_id = kwargs.get("_operation_id")
+        task_id = kwargs.get("_task_id")
+
         if not agent_id:
             return ErrorResponse(
                 message="Please provide the agent ID to edit.",
@@ -133,7 +138,12 @@ class EditAgentTool(BaseTool):

         # Step 2: Generate updated agent (external service handles fixing and validation)
         try:
-            result = await generate_agent_patch(update_request, current_agent)
+            result = await generate_agent_patch(
+                update_request,
+                current_agent,
+                operation_id=operation_id,
+                task_id=task_id,
+            )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -152,6 +162,19 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

+        # Check if Agent Generator accepted for async processing
+        if result.get("status") == "accepted":
+            logger.info(
+                f"Agent edit delegated to async processing "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return AsyncProcessingResponse(
+                message="Agent edit started. You'll be notified when it's complete.",
+                operation_id=operation_id,
+                task_id=task_id,
+                session_id=session_id,
+            )
+
         # Check if LLM returned clarifying questions
         if result.get("type") == "clarifying_questions":
             questions = result.get("questions", [])
@@ -384,3 +384,20 @@ class OperationInProgressResponse(ToolResponseBase):

     type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
     tool_call_id: str
+
+
+class AsyncProcessingResponse(ToolResponseBase):
+    """Response when an operation has been delegated to async processing.
+
+    This is returned by tools when the external service accepts the request
+    for async processing (HTTP 202 Accepted). The RabbitMQ completion consumer
+    will handle the result when the external service completes.
+
+    The status field is specifically "accepted" to allow the long-running tool
+    handler to detect this response and skip LLM continuation.
+    """
+
+    type: ResponseType = ResponseType.OPERATION_STARTED
+    status: str = "accepted"  # Must be "accepted" for detection
+    operation_id: str | None = None
+    task_id: str | None = None