mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): Handle ChatCompletionMessage objects in token counting functions
The actual issue was in the _msg_tokens function in prompt.py which expected dict objects but was receiving ChatCompletionMessage objects. This fix: - Updates _msg_tokens to convert non-dict messages to dicts using json.to_dict - Updates compress_prompt to handle mixed message types - Updates estimate_token_count to accept Any types in addition to dicts This ensures that all message objects are properly converted to dictionaries before accessing their properties, preventing AttributeError exceptions. Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
This commit is contained in:
@@ -15,11 +15,15 @@ def _tok_len(text: str, enc) -> int:
|
||||
return len(enc.encode(str(text)))
|
||||
|
||||
|
||||
def _msg_tokens(msg: dict, enc) -> int:
|
||||
def _msg_tokens(msg: dict | Any, enc) -> int:
|
||||
"""
|
||||
OpenAI counts ≈3 wrapper tokens per chat message, plus 1 if "name"
|
||||
is present, plus the tokenised content length.
|
||||
"""
|
||||
# Handle ChatCompletionMessage objects by converting to dict
|
||||
if not isinstance(msg, dict):
|
||||
msg = json.to_dict(msg)
|
||||
|
||||
WRAPPER = 3 + (1 if "name" in msg else 0)
|
||||
return WRAPPER + _tok_len(msg.get("content") or "", enc)
|
||||
|
||||
@@ -46,7 +50,7 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
|
||||
|
||||
def compress_prompt(
|
||||
messages: list[dict],
|
||||
messages: list[dict | Any],
|
||||
target_tokens: int,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
@@ -94,7 +98,11 @@ def compress_prompt(
|
||||
list[dict] – A *new* messages list that abides by the rules above.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
msgs = deepcopy(messages) # never mutate caller
|
||||
# Convert any ChatCompletionMessage objects to dicts first
|
||||
messages_as_dicts = [
|
||||
json.to_dict(m) if not isinstance(m, dict) else m for m in messages
|
||||
]
|
||||
msgs = deepcopy(messages_as_dicts) # never mutate caller
|
||||
|
||||
def total_tokens() -> int:
|
||||
"""Current size of *msgs* in tokens."""
|
||||
@@ -162,7 +170,7 @@ def compress_prompt(
|
||||
|
||||
|
||||
def estimate_token_count(
|
||||
messages: list[dict],
|
||||
messages: list[dict | Any],
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
) -> int:
|
||||
|
||||
Reference in New Issue
Block a user