fmt issues

This commit is contained in:
Swifty
2026-01-29 13:28:01 +01:00
parent d738059da8
commit c01c29a059
4 changed files with 31 additions and 29 deletions

View File

@@ -7,12 +7,17 @@ stream registry and chat service updates.
import asyncio
import logging
from typing import Any
import orjson
from pydantic import BaseModel
from backend.data.rabbitmq import AsyncRabbitMQ, Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
)
from . import service as chat_service
from . import stream_registry
@@ -182,9 +187,11 @@ class ChatCompletionConsumer:
result_str = (
message.result
if isinstance(message.result, str)
else orjson.dumps(message.result).decode("utf-8")
if message.result
else '{"status": "completed"}'
else (
orjson.dumps(message.result).decode("utf-8")
if message.result
else '{"status": "completed"}'
)
)
await chat_service._update_pending_operation(
session_id=task.session_id,

View File

@@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Query, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
@@ -491,9 +491,6 @@ async def get_task_status(
# ========== External Completion Webhook ==========
from fastapi import Header, HTTPException
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
@@ -527,8 +524,8 @@ async def complete_operation(
else:
# If no internal API key is configured, log a warning
logger.warning(
f"Operation complete webhook called without API key validation "
f"(CHAT_INTERNAL_API_KEY not configured)"
"Operation complete webhook called without API key validation "
"(CHAT_INTERNAL_API_KEY not configured)"
)
# Find task by operation_id
@@ -554,7 +551,11 @@ async def complete_operation(
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=result_output if isinstance(result_output, str) else str(result_output),
output=(
result_output
if isinstance(result_output, str)
else str(result_output)
),
success=True,
),
)

View File

@@ -26,6 +26,7 @@ from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings
from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
ChatMessage,
@@ -59,7 +60,6 @@ from .tools.models import (
OperationStartedResponse,
)
from .tracking import track_user_message
from . import stream_registry
logger = logging.getLogger(__name__)
@@ -2134,12 +2134,8 @@ async def _generate_llm_continuation_with_streaming(
text_block_id = str(uuid_module.uuid4())
# Publish start event
await stream_registry.publish_chunk(
task_id, StreamStart(messageId=message_id)
)
await stream_registry.publish_chunk(
task_id, StreamTextStart(id=text_block_id)
)
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response
stream = await client.chat.completions.create(
@@ -2161,9 +2157,7 @@ async def _generate_llm_continuation_with_streaming(
)
# Publish end events
await stream_registry.publish_chunk(
task_id, StreamTextEnd(id=text_block_id)
)
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
if assistant_content:
# Reload session from DB to avoid race condition with user messages

View File

@@ -12,7 +12,7 @@ import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal
from typing import Any, Literal
import orjson
@@ -103,7 +103,7 @@ async def create_task(
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset(
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
@@ -250,7 +250,7 @@ async def subscribe_to_task(
# Try to load from Redis if not in memory
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta = await redis.hgetall(meta_key)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.warning(f"Task {task_id} not found in memory or Redis")
@@ -318,7 +318,7 @@ async def mark_task_completed(
# Update Redis metadata
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
await redis.hset(meta_key, "status", status)
await redis.hset(meta_key, "status", status) # type: ignore[misc]
logger.info(f"Marked task {task_id} as {status}")
@@ -352,7 +352,7 @@ async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
# Load metadata from Redis
meta_key = _get_task_meta_key(task_id_str)
meta = await redis.hgetall(meta_key)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
# Reconstruct task object (not fully active, but has metadata)
@@ -385,7 +385,7 @@ async def get_task(task_id: str) -> ActiveTask | None:
# Try Redis lookup
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta = await redis.hgetall(meta_key)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
return ActiveTask(
@@ -395,7 +395,7 @@ async def get_task(task_id: str) -> ActiveTask | None:
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
tool_name=meta.get(b"tool_name", b"").decode(),
operation_id=meta.get(b"operation_id", b"").decode(),
status=meta.get(b"status", b"running").decode(), # type: ignore
status=meta.get(b"status", b"running").decode(), # type: ignore[arg-type]
)
return None