Add client close (#5871)

Fixes #4821 by adding a `close()` method to all clients.

Additionally:
* The m1 CLI is updated to close the client before exiting.
* The PlaywrightController is updated to suppress some other unrelated
chatty warnings (e.g., produced by markitdown when encountering
conversions that require external utilities)
This commit is contained in:
afourney
2025-03-07 14:10:06 -08:00
committed by GitHub
parent dd82883a90
commit 8f737de0e1
12 changed files with 37 additions and 7 deletions

View File

@@ -152,6 +152,9 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]: ...
# Abstract lifecycle hook: concrete clients override this to release any
# resources they hold (e.g. an underlying SDK/HTTP client) before shutdown.
@abstractmethod
async def close(self) -> None: ...
@abstractmethod
def actual_usage(self) -> RequestUsage: ...

View File

@@ -126,6 +126,9 @@ async def test_caller_loop() -> None:
) -> AsyncGenerator[Union[str, CreateResult], None]:
raise NotImplementedError()
async def close(self) -> None:
    """No-op close for this test double; it holds no resources to release."""
    return None
def actual_usage(self) -> RequestUsage:
    """Report usage for this test double: always zero tokens."""
    no_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
    return no_usage

View File

@@ -3,10 +3,13 @@ import base64
import io
import os
import random
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
# TODO: Fix unfollowed import
try:
# Suppress warnings from markitdown -- which is pretty chatty
warnings.filterwarnings(action="ignore", module="markitdown")
from markitdown import MarkItDown # type: ignore
except ImportError:
MarkItDown = None

View File

@@ -166,6 +166,9 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
cancellation_token=cancellation_token,
)
async def close(self) -> None:
    """Close the recorder by closing the client it wraps."""
    wrapped = self.base_client
    await wrapped.close()
def actual_usage(self) -> RequestUsage:
    """Delegate usage reporting to the wrapped base client."""
    wrapped = self.base_client
    return wrapped.actual_usage()

View File

@@ -775,6 +775,9 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
yield result
async def close(self) -> None:
    """Release resources held by the underlying SDK client."""
    underlying = self._client
    await underlying.close()
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
"""
Estimate the number of tokens used by messages and tools.

View File

@@ -490,6 +490,9 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
yield result
async def close(self) -> None:
    """Shut down the wrapped client, releasing its resources."""
    # Forward the close to the inner client object.
    await self._client.close()
def actual_usage(self) -> RequestUsage:
    """Return the usage recorded for the most recent request."""
    last_usage = self._actual_usage
    return last_usage

View File

@@ -206,6 +206,9 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
return _generator()
async def close(self) -> None:
    """Close the cache wrapper by closing the client it delegates to."""
    inner = self.client
    await inner.close()
def actual_usage(self) -> RequestUsage:
    """Forward usage reporting to the underlying client."""
    inner = self.client
    return inner.actual_usage()

View File

@@ -772,6 +772,9 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
yield result
async def close(self) -> None:
    """Satisfy the ChatCompletionClient close contract.

    NOTE(review): the ollama client does not appear to expose a close()
    method — confirm against the ollama SDK; nothing to release here.
    """
    return None
def actual_usage(self) -> RequestUsage:
    """Return the usage recorded for the most recent request."""
    recent = self._actual_usage
    return recent

View File

@@ -944,6 +944,9 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
except StopAsyncIteration:
break
async def close(self) -> None:
    """Dispose of the underlying SDK client and its connections."""
    sdk_client = self._client
    await sdk_client.close()
def actual_usage(self) -> RequestUsage:
    """Report the usage of the most recent request."""
    usage_snapshot = self._actual_usage
    return usage_snapshot

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import logging
import warnings
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union
from typing_extensions import Self
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component
from autogen_core.models import (
@@ -18,6 +17,7 @@ from autogen_core.models import (
)
from autogen_core.tools import Tool, ToolSchema
from pydantic import BaseModel
from typing_extensions import Self
logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -229,6 +229,9 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
self._current_index += 1
async def close(self) -> None:
    """No-op: the replay client holds no external resources to release."""
    return None
def actual_usage(self) -> RequestUsage:
    """Return the usage tracked for the current replay position."""
    current = self._cur_usage
    return current

View File

@@ -654,6 +654,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
thought=thought,
)
async def close(self) -> None:
    """Satisfy the ChatCompletionClient close contract.

    NOTE(review): the Semantic Kernel client does not appear to expose an
    explicit close() — confirm against the SK SDK; nothing to do here.
    """
    return None
def actual_usage(self) -> RequestUsage:
    """Report cumulative prompt/completion token counts tracked by the adapter."""
    prompt = self._total_prompt_tokens
    completion = self._total_completion_tokens
    return RequestUsage(prompt_tokens=prompt, completion_tokens=completion)

View File

@@ -2,7 +2,6 @@ import argparse
import asyncio
import os
import sys
import warnings
from typing import Any, Dict, Optional
import yaml
@@ -13,9 +12,6 @@ from autogen_ext.code_executors.docker import DockerCommandLineCodeExecutor
from autogen_ext.teams.magentic_one import MagenticOne
from autogen_ext.ui import RichConsole
# Suppress warnings about the requests.Session() not being closed
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
DEFAULT_CONFIG_FILE = "config.yaml"
DEFAULT_CONFIG_CONTENTS = """# config.yaml
#
@@ -109,10 +105,9 @@ def main() -> None:
with open(args.config if isinstance(args.config, str) else args.config[0], "r") as f:
config = yaml.safe_load(f)
client = ChatCompletionClient.load_component(config["client"])
# Run the task
async def run_task(task: str, hil_mode: bool, use_rich_console: bool) -> None:
client = ChatCompletionClient.load_component(config["client"])
input_manager = UserInputManager(callback=cancellable_input)
async with DockerCommandLineCodeExecutor(work_dir=os.getcwd()) as code_executor:
@@ -128,6 +123,8 @@ def main() -> None:
else:
await Console(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager)
await client.close()
task = args.task if isinstance(args.task, str) else args.task[0]
asyncio.run(run_task(task, not args.no_hil, args.rich))