Mirror of https://github.com/microsoft/autogen.git, synced 2026-04-20 03:02:16 -04:00
add cost calculation to client (#769)
* add cost calculation
* Update autogen/oai/client.py
  Co-authored-by: Joshua Kim <joshkyh@users.noreply.github.com>
* Update autogen/oai/client.py
  Co-authored-by: Joshua Kim <joshkyh@users.noreply.github.com>
* update
* add doc
---------
Co-authored-by: Joshua Kim <joshkyh@users.noreply.github.com>
@@ -2,12 +2,12 @@ from __future__ import annotations
 import os
 import sys
-from typing import List, Optional, Dict, Callable
+from typing import List, Optional, Dict, Callable, Union
 import logging
 import inspect
 from flaml.automl.logger import logger_formatter

-from autogen.oai.openai_utils import get_key
+from autogen.oai.openai_utils import get_key, oai_price1k
+from autogen.token_count_utils import count_token

 try:
@@ -240,7 +240,7 @@ class OpenAIWrapper:
                     # Return the response if it passes the filter or it is the last client
                     response.config_id = i
                     response.pass_filter = pass_filter
-                    # TODO: add response.cost
+                    response.cost = self.cost(response)
                     return response
                 continue  # filter is not passed; try the next config
             try:
@@ -261,10 +261,25 @@ class OpenAIWrapper:
                 # Return the response if it passes the filter or it is the last client
                 response.config_id = i
                 response.pass_filter = pass_filter
-                # TODO: add response.cost
+                response.cost = self.cost(response)
                 return response
             continue  # filter is not passed; try the next config

+    def cost(self, response: Union[ChatCompletion, Completion]) -> float:
+        """Calculate the cost of the response."""
+        model = response.model
+        if model not in oai_price1k:
+            # TODO: add logging to warn that the model is not found
+            return 0
+
+        n_input_tokens = response.usage.prompt_tokens
+        n_output_tokens = response.usage.completion_tokens
+        tmp_price1K = oai_price1k[model]
+        # First value is input token rate, second value is output token rate
+        if isinstance(tmp_price1K, tuple):
+            return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
+        return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
+
     def _completions_create(self, client, params):
         completions = client.chat.completions if "messages" in params else client.completions
         # If streaming is enabled, has messages, and does not have functions, then
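The arithmetic in the new cost() method is plain per-1K-token pricing. A minimal standalone sketch of the same calculation, using a price-table excerpt copied from this diff and made-up token counts (not output from the library):

# Standalone sketch of the pricing math used by cost(); values are illustrative.
price1k = {
    "gpt-4": (0.03, 0.06),     # (input rate, output rate) per 1K tokens
    "text-davinci-003": 0.02,  # single flat rate per 1K tokens
}

def estimate_cost(model: str, n_input_tokens: int, n_output_tokens: int) -> float:
    if model not in price1k:
        return 0  # unknown model, mirroring cost()'s fallback
    rate = price1k[model]
    if isinstance(rate, tuple):
        return (rate[0] * n_input_tokens + rate[1] * n_output_tokens) / 1000
    return rate * (n_input_tokens + n_output_tokens) / 1000

# Example: 100 prompt tokens and 50 completion tokens on gpt-4
# 0.03 * 100 / 1000 + 0.06 * 50 / 1000 = 0.003 + 0.003 = 0.006
print(estimate_cost("gpt-4", 100, 50))  # 0.006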
@@ -9,6 +9,36 @@ from dotenv import find_dotenv, load_dotenv

 NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]

+oai_price1k = {
+    "text-ada-001": 0.0004,
+    "text-babbage-001": 0.0005,
+    "text-curie-001": 0.002,
+    "code-cushman-001": 0.024,
+    "code-davinci-002": 0.1,
+    "text-davinci-002": 0.02,
+    "text-davinci-003": 0.02,
+    "gpt-3.5-turbo-instruct": (0.0015, 0.002),
+    "gpt-3.5-turbo-0301": (0.0015, 0.002),  # deprecate in Sep
+    "gpt-3.5-turbo-0613": (0.0015, 0.002),
+    "gpt-3.5-turbo-16k": (0.003, 0.004),
+    "gpt-3.5-turbo-16k-0613": (0.003, 0.004),
+    "gpt-35-turbo": (0.0015, 0.002),
+    "gpt-35-turbo-16k": (0.003, 0.004),
+    "gpt-35-turbo-instruct": (0.0015, 0.002),
+    "gpt-4": (0.03, 0.06),
+    "gpt-4-32k": (0.06, 0.12),
+    "gpt-4-0314": (0.03, 0.06),  # deprecate in Sep
+    "gpt-4-32k-0314": (0.06, 0.12),  # deprecate in Sep
+    "gpt-4-0613": (0.03, 0.06),
+    "gpt-4-32k-0613": (0.06, 0.12),
+    # 11-06
+    "gpt-3.5-turbo": (0.001, 0.002),
+    "gpt-3.5-turbo-1106": (0.001, 0.002),
+    "gpt-35-turbo-1106": (0.001, 0.002),
+    "gpt-4-1106-preview": (0.01, 0.03),
+    "gpt-4-1106-vision-preview": (0.01, 0.03),  # TODO: support vision pricing of images
+}
+

 def get_key(config):
     """Get a unique identifier of a configuration.
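Because oai_price1k is a plain module-level dict (as the hunk above shows) and cost() falls back to 0 for models it does not know, a caller could register a price for an unlisted model before creating completions. A small sketch under that assumption; the model name and rates below are placeholders, not real prices:

from autogen.oai.openai_utils import oai_price1k

# Hypothetical entry: "my-custom-model" and its rates are placeholders.
oai_price1k["my-custom-model"] = (0.002, 0.004)  # (input, output) per 1K tokens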
@@ -48,7 +48,24 @@ def test_completion():
     print(client.extract_text_or_function_call(response))


+@pytest.mark.skipif(skip, reason="openai>=1 not installed")
+@pytest.mark.parametrize(
+    "cache_seed, model",
+    [
+        (None, "gpt-3.5-turbo-instruct"),
+        (42, "gpt-3.5-turbo-instruct"),
+        (None, "text-ada-001"),
+    ],
+)
+def test_cost(cache_seed, model):
+    config_list = config_list_openai_aoai(KEY_LOC)
+    client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed)
+    response = client.create(prompt="1+3=", model=model)
+    print(response.cost)
+
+
 if __name__ == "__main__":
     test_aoai_chat_completion()
     test_chat_completion()
     test_completion()
+    test_cost()
@@ -122,6 +122,8 @@ client = OpenAIWrapper()
 response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo")
 # extract the response text
 print(client.extract_text_or_function_call(response))
+# get cost of this completion
+print(response.cost)
 # Azure OpenAI endpoint
 client = OpenAIWrapper(api_key=..., base_url=..., api_version=..., api_type="azure")
 # Completion
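Since every response returned by create() now carries a numeric cost attribute, per-call costs can be summed in user code. A minimal sketch, assuming the OpenAIWrapper client from the documentation snippet above; the prompts are arbitrary:

total_cost = 0.0
for content in ["1+1=", "2+2="]:
    response = client.create(messages=[{"role": "user", "content": content}], model="gpt-3.5-turbo")
    total_cost += response.cost
print(f"total cost of both completions: ${total_cost:.6f}")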