mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Add client close (#5871)
Fixes #4821 by adding a `close()` method to all clients. Additionally: * The m1 CLI is updated to close the client before exiting. * The playwrightcontroller is updated to suppress some other unrelated chatty warnings (e.g,, produced by markitdown when encountering conversions that require external utilities)
This commit is contained in:
@@ -152,6 +152,9 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def actual_usage(self) -> RequestUsage: ...
|
||||
|
||||
|
||||
@@ -126,6 +126,9 @@ async def test_caller_loop() -> None:
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
def actual_usage(self) -> RequestUsage:
|
||||
return RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
|
||||
@@ -3,10 +3,13 @@ import base64
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
# TODO: Fix unfollowed import
|
||||
try:
|
||||
# Suppress warnings from markitdown -- which is pretty chatty
|
||||
warnings.filterwarnings(action="ignore", module="markitdown")
|
||||
from markitdown import MarkItDown # type: ignore
|
||||
except ImportError:
|
||||
MarkItDown = None
|
||||
|
||||
@@ -166,6 +166,9 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.base_client.close()
|
||||
|
||||
def actual_usage(self) -> RequestUsage:
|
||||
# Calls base_client.actual_usage() and returns the result.
|
||||
return self.base_client.actual_usage()
|
||||
|
||||
@@ -775,6 +775,9 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
|
||||
|
||||
yield result
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.close()
|
||||
|
||||
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
|
||||
"""
|
||||
Estimate the number of tokens used by messages and tools.
|
||||
|
||||
@@ -490,6 +490,9 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
|
||||
yield result
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.close()
|
||||
|
||||
def actual_usage(self) -> RequestUsage:
|
||||
return self._actual_usage
|
||||
|
||||
|
||||
@@ -206,6 +206,9 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
|
||||
|
||||
return _generator()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.client.close()
|
||||
|
||||
def actual_usage(self) -> RequestUsage:
|
||||
return self.client.actual_usage()
|
||||
|
||||
|
||||
@@ -772,6 +772,9 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
||||
|
||||
yield result
|
||||
|
||||
async def close(self) -> None:
|
||||
pass # ollama has no close method?
|
||||
|
||||
def actual_usage(self) -> RequestUsage:
|
||||
return self._actual_usage
|
||||
|
||||
|
||||
@@ -944,6 +944,9 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.close()
|
||||
|
||||
def actual_usage(self) -> RequestUsage:
|
||||
return self._actual_usage
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component
|
||||
from autogen_core.models import (
|
||||
@@ -18,6 +17,7 @@ from autogen_core.models import (
|
||||
)
|
||||
from autogen_core.tools import Tool, ToolSchema
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
@@ -229,6 +229,9 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
|
||||
|
||||
self._current_index += 1
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
def actual_usage(self) -> RequestUsage:
|
||||
return self._cur_usage
|
||||
|
||||
|
||||
@@ -654,6 +654,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
thought=thought,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass # No explicit close method in SK client?
|
||||
|
||||
def actual_usage(self) -> RequestUsage:
|
||||
return RequestUsage(prompt_tokens=self._total_prompt_tokens, completion_tokens=self._total_completion_tokens)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
@@ -13,9 +12,6 @@ from autogen_ext.code_executors.docker import DockerCommandLineCodeExecutor
|
||||
from autogen_ext.teams.magentic_one import MagenticOne
|
||||
from autogen_ext.ui import RichConsole
|
||||
|
||||
# Suppress warnings about the requests.Session() not being closed
|
||||
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
|
||||
|
||||
DEFAULT_CONFIG_FILE = "config.yaml"
|
||||
DEFAULT_CONFIG_CONTENTS = """# config.yaml
|
||||
#
|
||||
@@ -109,10 +105,9 @@ def main() -> None:
|
||||
with open(args.config if isinstance(args.config, str) else args.config[0], "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
client = ChatCompletionClient.load_component(config["client"])
|
||||
|
||||
# Run the task
|
||||
async def run_task(task: str, hil_mode: bool, use_rich_console: bool) -> None:
|
||||
client = ChatCompletionClient.load_component(config["client"])
|
||||
input_manager = UserInputManager(callback=cancellable_input)
|
||||
|
||||
async with DockerCommandLineCodeExecutor(work_dir=os.getcwd()) as code_executor:
|
||||
@@ -128,6 +123,8 @@ def main() -> None:
|
||||
else:
|
||||
await Console(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager)
|
||||
|
||||
await client.close()
|
||||
|
||||
task = args.task if isinstance(args.task, str) else args.task[0]
|
||||
asyncio.run(run_task(task, not args.no_hil, args.rich))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user