diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py index fd295c3494..4ad7afb9e9 100644 --- a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py +++ b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py @@ -7,12 +7,17 @@ stream registry and chat service updates. import asyncio import logging -from typing import Any import orjson from pydantic import BaseModel -from backend.data.rabbitmq import AsyncRabbitMQ, Exchange, ExchangeType, Queue, RabbitMQConfig +from backend.data.rabbitmq import ( + AsyncRabbitMQ, + Exchange, + ExchangeType, + Queue, + RabbitMQConfig, +) from . import service as chat_service from . import stream_registry @@ -182,9 +187,11 @@ class ChatCompletionConsumer: result_str = ( message.result if isinstance(message.result, str) - else orjson.dumps(message.result).decode("utf-8") - if message.result - else '{"status": "completed"}' + else ( + orjson.dumps(message.result).decode("utf-8") + if message.result + else '{"status": "completed"}' + ) ) await chat_service._update_pending_operation( session_id=task.session_id, diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index b085b6d83d..bc9611021f 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator from typing import Annotated from autogpt_libs import auth -from fastapi import APIRouter, Depends, Query, Security +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security from fastapi.responses import StreamingResponse from pydantic import BaseModel @@ -491,9 +491,6 @@ async def get_task_status( # ========== External Completion Webhook ========== -from fastapi import Header, HTTPException - - @router.post( "/operations/{operation_id}/complete", status_code=200, @@ -527,8 +524,8 @@ async def complete_operation( else: # If no internal API key is configured, log a warning logger.warning( - f"Operation complete webhook called without API key validation " - f"(CHAT_INTERNAL_API_KEY not configured)" + "Operation complete webhook called without API key validation " + "(CHAT_INTERNAL_API_KEY not configured)" ) # Find task by operation_id @@ -554,7 +551,11 @@ async def complete_operation( StreamToolOutputAvailable( toolCallId=task.tool_call_id, toolName=task.tool_name, - output=result_output if isinstance(result_output, str) else str(result_output), + output=( + result_output + if isinstance(result_output, str) + else str(result_output) + ), success=True, ), ) diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 53f7f863af..09eb8e6093 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -26,6 +26,7 @@ from backend.util.exceptions import NotFoundError from backend.util.settings import Settings from . import db as chat_db +from . import stream_registry from .config import ChatConfig from .model import ( ChatMessage, @@ -59,7 +60,6 @@ from .tools.models import ( OperationStartedResponse, ) from .tracking import track_user_message -from . import stream_registry logger = logging.getLogger(__name__) @@ -2134,12 +2134,8 @@ async def _generate_llm_continuation_with_streaming( text_block_id = str(uuid_module.uuid4()) # Publish start event - await stream_registry.publish_chunk( - task_id, StreamStart(messageId=message_id) - ) - await stream_registry.publish_chunk( - task_id, StreamTextStart(id=text_block_id) - ) + await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id)) + await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id)) # Stream the response stream = await client.chat.completions.create( @@ -2161,9 +2157,7 @@ async def _generate_llm_continuation_with_streaming( ) # Publish end events - await stream_registry.publish_chunk( - task_id, StreamTextEnd(id=text_block_id) - ) + await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id)) if assistant_content: # Reload session from DB to avoid race condition with user messages diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py index 9839c4ff48..72d4488c0e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -12,7 +12,7 @@ import asyncio import logging from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Literal +from typing import Any, Literal import orjson @@ -103,7 +103,7 @@ async def create_task( meta_key = _get_task_meta_key(task_id) op_key = _get_operation_mapping_key(operation_id) - await redis.hset( + await redis.hset( # type: ignore[misc] meta_key, mapping={ "task_id": task_id, @@ -250,7 +250,7 @@ async def subscribe_to_task( # Try to load from Redis if not in memory redis = await get_redis_async() meta_key = _get_task_meta_key(task_id) - meta = await redis.hgetall(meta_key) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] if not meta: logger.warning(f"Task {task_id} not found in memory or Redis") @@ -318,7 +318,7 @@ async def mark_task_completed( # Update Redis metadata redis = await get_redis_async() meta_key = _get_task_meta_key(task_id) - await redis.hset(meta_key, "status", status) + await redis.hset(meta_key, "status", status) # type: ignore[misc] logger.info(f"Marked task {task_id} as {status}") @@ -352,7 +352,7 @@ async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None: # Load metadata from Redis meta_key = _get_task_meta_key(task_id_str) - meta = await redis.hgetall(meta_key) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] if meta: # Reconstruct task object (not fully active, but has metadata) @@ -385,7 +385,7 @@ async def get_task(task_id: str) -> ActiveTask | None: # Try Redis lookup redis = await get_redis_async() meta_key = _get_task_meta_key(task_id) - meta = await redis.hgetall(meta_key) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] if meta: return ActiveTask( @@ -395,7 +395,7 @@ async def get_task(task_id: str) -> ActiveTask | None: tool_call_id=meta.get(b"tool_call_id", b"").decode(), tool_name=meta.get(b"tool_name", b"").decode(), operation_id=meta.get(b"operation_id", b"").decode(), - status=meta.get(b"status", b"running").decode(), # type: ignore + status=meta.get(b"status", b"running").decode(), # type: ignore[arg-type] ) return None