Add token usage termination (#4035)

* Add token usage termination

* fix test
This commit is contained in:
Eric Zhu
2024-11-01 15:01:43 -07:00
committed by GitHub
parent ca7caa779d
commit 27ea99a485
3 changed files with 110 additions and 2 deletions

View File

@@ -1,7 +1,8 @@
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination, TokenUsageTermination
# Names exported when this package is star-imported; each entry is one of the
# termination classes imported from ._terminations.
__all__ = [
"MaxMessageTermination",
"TextMentionTermination",
"StopMessageTermination",
"TokenUsageTermination",
]

View File

@@ -88,3 +88,59 @@ class TextMentionTermination(TerminationCondition):
async def reset(self) -> None:
    """Reset the condition by clearing its ``_terminated`` flag."""
    self._terminated = False
class TokenUsageTermination(TerminationCondition):
    """Stop the conversation once accumulated token usage reaches a limit.

    Token counts are accumulated from the ``model_usage`` of every message
    passed to the condition; messages whose ``model_usage`` is ``None`` do not
    contribute. The total count is the sum of prompt and completion tokens.
    The condition is met as soon as any configured limit is reached or
    exceeded.

    Args:
        max_total_token: Limit on prompt plus completion tokens, or ``None``
            for no total limit.
        max_prompt_token: Limit on prompt tokens, or ``None`` for no prompt
            limit.
        max_completion_token: Limit on completion tokens, or ``None`` for no
            completion limit.

    Raises:
        ValueError: If all three limits are ``None``.
    """

    def __init__(
        self,
        max_total_token: int | None = None,
        max_prompt_token: int | None = None,
        max_completion_token: int | None = None,
    ) -> None:
        limits = (max_total_token, max_prompt_token, max_completion_token)
        if all(limit is None for limit in limits):
            raise ValueError(
                "At least one of max_total_token, max_prompt_token, or max_completion_token must be provided"
            )
        # Running counters start at zero; reset() returns them to this state.
        self._total_token_count = 0
        self._prompt_token_count = 0
        self._completion_token_count = 0
        self._max_total_token = max_total_token
        self._max_prompt_token = max_prompt_token
        self._max_completion_token = max_completion_token

    @property
    def terminated(self) -> bool:
        """Whether any configured limit has been reached or exceeded."""
        limit_and_count = (
            (self._max_total_token, self._total_token_count),
            (self._max_prompt_token, self._prompt_token_count),
            (self._max_completion_token, self._completion_token_count),
        )
        # Unset limits (None) never trigger termination.
        return any(limit is not None and count >= limit for limit, count in limit_and_count)

    async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:
        """Accumulate usage from ``messages`` and report whether to stop.

        Args:
            messages: Messages whose ``model_usage`` is added to the counters.

        Returns:
            A ``StopMessage`` describing the counts if a limit is now reached,
            otherwise ``None``.

        Raises:
            TerminatedException: If a limit had already been reached before
                this call.
        """
        if self.terminated:
            raise TerminatedException("Termination condition has already been reached")

        for message in messages:
            usage = message.model_usage
            if usage is None:
                continue
            self._prompt_token_count += usage.prompt_tokens
            self._completion_token_count += usage.completion_tokens
            self._total_token_count += usage.prompt_tokens + usage.completion_tokens

        if not self.terminated:
            return None

        content = f"Token usage limit reached, total token count: {self._total_token_count}, prompt token count: {self._prompt_token_count}, completion token count: {self._completion_token_count}."
        return StopMessage(content=content, source="TokenUsageTermination")

    async def reset(self) -> None:
        """Zero all token counters; the configured limits are left unchanged."""
        self._total_token_count = self._prompt_token_count = self._completion_token_count = 0

View File

@@ -1,6 +1,12 @@
import pytest
from autogen_agentchat.messages import StopMessage, TextMessage
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_agentchat.task import (
MaxMessageTermination,
StopMessageTermination,
TextMentionTermination,
TokenUsageTermination,
)
from autogen_core.components.models import RequestUsage
@pytest.mark.asyncio
@@ -51,6 +57,51 @@ async def test_mention_termination() -> None:
)
@pytest.mark.asyncio
async def test_token_usage_termination() -> None:
    """Check TokenUsageTermination against a total-token limit of 10."""

    def make_message(content: str, source: str, prompt_tokens: int, completion_tokens: int) -> TextMessage:
        # Build a text message carrying the given token usage.
        usage = RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
        return TextMessage(content=content, source=source, model_usage=usage)

    termination = TokenUsageTermination(max_total_token=10)

    # An empty batch adds no usage, so nothing is returned.
    outcome = await termination([])
    assert outcome is None
    await termination.reset()

    # A single message with 20 tokens in total exceeds the limit.
    outcome = await termination([make_message("Hello", "user", 10, 10)])
    assert outcome is not None
    await termination.reset()

    # Two messages totalling 4 tokens stay below the limit.
    outcome = await termination(
        [
            make_message("Hello", "user", 1, 1),
            make_message("World", "agent", 1, 1),
        ]
    )
    assert outcome is None
    await termination.reset()

    # 5 prompt tokens plus 5 completion tokens reach the limit exactly.
    outcome = await termination(
        [
            make_message("Hello", "user", 5, 0),
            make_message("stop", "user", 0, 5),
        ]
    )
    assert outcome is not None
@pytest.mark.asyncio
async def test_and_termination() -> None:
termination = MaxMessageTermination(2) & TextMentionTermination("stop")