mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Add token usage termination (#4035)
* Add token usage termination
* fix test
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
||||
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination, TokenUsageTermination
|
||||
|
||||
# Public API of the task module: the termination conditions re-exported above.
__all__ = [
    "MaxMessageTermination",
    "TextMentionTermination",
    "StopMessageTermination",
    "TokenUsageTermination",
]
||||
@@ -88,3 +88,59 @@ class TextMentionTermination(TerminationCondition):
|
||||
|
||||
async def reset(self) -> None:
    """Clear the terminated flag so this condition can be evaluated again."""
    self._terminated = False
|
||||
|
||||
class TokenUsageTermination(TerminationCondition):
    """Terminate the conversation once accumulated token usage reaches a limit.

    Token counts reported through each message's ``model_usage`` are summed
    across calls; the condition trips as soon as any configured limit is met
    or exceeded.

    Args:
        max_total_token: The maximum total number of tokens allowed in the conversation.
        max_prompt_token: The maximum number of prompt tokens allowed in the conversation.
        max_completion_token: The maximum number of completion tokens allowed in the conversation.

    Raises:
        ValueError: If none of max_total_token, max_prompt_token, or max_completion_token is provided.
    """

    def __init__(
        self,
        max_total_token: int | None = None,
        max_prompt_token: int | None = None,
        max_completion_token: int | None = None,
    ) -> None:
        # Without at least one limit the condition could never fire.
        if max_total_token is None and max_prompt_token is None and max_completion_token is None:
            raise ValueError(
                "At least one of max_total_token, max_prompt_token, or max_completion_token must be provided"
            )
        self._max_total_token = max_total_token
        self._max_prompt_token = max_prompt_token
        self._max_completion_token = max_completion_token
        # Running totals, accumulated across every __call__ until reset().
        self._total_token_count = 0
        self._prompt_token_count = 0
        self._completion_token_count = 0

    @property
    def terminated(self) -> bool:
        # A limit of None is disabled; any enabled limit that has been reached trips the condition.
        limits_and_counts = (
            (self._max_total_token, self._total_token_count),
            (self._max_prompt_token, self._prompt_token_count),
            (self._max_completion_token, self._completion_token_count),
        )
        return any(limit is not None and count >= limit for limit, count in limits_and_counts)

    async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:
        if self.terminated:
            raise TerminatedException("Termination condition has already been reached")
        # Fold the usage of every message that reports it into the running totals.
        for msg in messages:
            usage = msg.model_usage
            if usage is None:
                continue
            self._prompt_token_count += usage.prompt_tokens
            self._completion_token_count += usage.completion_tokens
            self._total_token_count += usage.prompt_tokens + usage.completion_tokens
        if not self.terminated:
            return None
        content = f"Token usage limit reached, total token count: {self._total_token_count}, prompt token count: {self._prompt_token_count}, completion token count: {self._completion_token_count}."
        return StopMessage(content=content, source="TokenUsageTermination")

    async def reset(self) -> None:
        """Zero the accumulated token counts so the condition can be reused."""
        self._total_token_count = 0
        self._prompt_token_count = 0
        self._completion_token_count = 0
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import pytest
|
||||
from autogen_agentchat.messages import StopMessage, TextMessage
|
||||
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
||||
from autogen_agentchat.task import (
|
||||
MaxMessageTermination,
|
||||
StopMessageTermination,
|
||||
TextMentionTermination,
|
||||
TokenUsageTermination,
|
||||
)
|
||||
from autogen_core.components.models import RequestUsage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -51,6 +57,51 @@ async def test_mention_termination() -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_token_usage_termination() -> None:
    termination = TokenUsageTermination(max_total_token=10)

    # No messages means no usage recorded, so the condition must not trip.
    assert await termination([]) is None
    await termination.reset()

    # A single message whose usage alone meets the limit trips the condition.
    over_limit = [
        TextMessage(
            content="Hello", source="user", model_usage=RequestUsage(prompt_tokens=10, completion_tokens=10)
        )
    ]
    assert await termination(over_limit) is not None
    await termination.reset()

    # Usage well under the limit leaves the condition untripped.
    under_limit = [
        TextMessage(
            content="Hello", source="user", model_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
        ),
        TextMessage(
            content="World", source="agent", model_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
        ),
    ]
    assert await termination(under_limit) is None
    await termination.reset()

    # Usage summed across messages that exactly reaches the limit trips it.
    at_limit = [
        TextMessage(
            content="Hello", source="user", model_usage=RequestUsage(prompt_tokens=5, completion_tokens=0)
        ),
        TextMessage(
            content="stop", source="user", model_usage=RequestUsage(prompt_tokens=0, completion_tokens=5)
        ),
    ]
    assert await termination(at_limit) is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_and_termination() -> None:
|
||||
termination = MaxMessageTermination(2) & TextMentionTermination("stop")
|
||||
|
||||
Reference in New Issue
Block a user