Extract OpenAI API retry handler and unify ADA embeddings calls. (#3191)

* Extract retry logic, unify embedding functions

* Add some docstrings

* Remove embedding creation from API manager

* Add test suite for retry handler

* Make api manager fixture

* Fix typing

* Streamline tests
This commit is contained in:
James Collins
2023-04-25 11:12:24 -07:00
committed by GitHub
parent 940b115f0a
commit 2619740daa
9 changed files with 242 additions and 93 deletions

View File

@@ -3,6 +3,8 @@ from pathlib import Path
import pytest
from dotenv import load_dotenv
from autogpt.api_manager import ApiManager
from autogpt.api_manager import api_manager as api_manager_
from autogpt.config import Config
from autogpt.workspace import Workspace
@@ -29,3 +31,11 @@ def config(workspace: Workspace) -> Config:
config.workspace_path = workspace.root
yield config
config.workspace_path = old_ws_path
@pytest.fixture()
def api_manager() -> ApiManager:
old_attrs = api_manager_.__dict__.copy()
api_manager_.reset()
yield api_manager_
api_manager_.__dict__.update(old_attrs)

View File

@@ -86,37 +86,6 @@ class TestApiManager:
assert api_manager.get_total_completion_tokens() == 20
assert api_manager.get_total_cost() == (10 * 0.002 + 20 * 0.002) / 1000
@staticmethod
def test_embedding_create_invalid_model():
"""Test if an invalid model for embedding raises a KeyError."""
text_list = ["Hello, how are you?"]
model = "invalid-model"
with patch("openai.Embedding.create") as mock_create:
mock_response = MagicMock()
mock_response.usage.prompt_tokens = 5
mock_create.side_effect = KeyError("Invalid model")
with pytest.raises(KeyError):
api_manager.embedding_create(text_list, model=model)
@staticmethod
def test_embedding_create_valid_inputs():
"""Test if valid inputs for embedding result in correct tokens and cost."""
text_list = ["Hello, how are you?"]
model = "text-embedding-ada-002"
with patch("openai.Embedding.create") as mock_create:
mock_response = MagicMock()
mock_response.usage.prompt_tokens = 5
mock_response["data"] = [{"embedding": [0.1, 0.2, 0.3]}]
mock_create.return_value = mock_response
api_manager.embedding_create(text_list, model=model)
assert api_manager.get_total_prompt_tokens() == 5
assert api_manager.get_total_completion_tokens() == 0
assert api_manager.get_total_cost() == (5 * 0.0004) / 1000
def test_getter_methods(self):
"""Test the getter methods for total tokens, cost, and budget."""
api_manager.update_cost(60, 120, "gpt-3.5-turbo")

129
tests/test_llm_utils.py Normal file
View File

@@ -0,0 +1,129 @@
import pytest
from openai.error import APIError, RateLimitError
from autogpt.llm_utils import get_ada_embedding, retry_openai_api
from autogpt.modelsinfo import COSTS
@pytest.fixture(params=[RateLimitError, APIError])
def error(request):
if request.param == APIError:
return request.param("Error", http_status=502)
else:
return request.param("Error")
@pytest.fixture
def mock_create_embedding(mocker):
mock_response = mocker.MagicMock()
mock_response.usage.prompt_tokens = 5
mock_response.__getitem__.side_effect = lambda key: [{"embedding": [0.1, 0.2, 0.3]}]
return mocker.patch(
"autogpt.llm_utils.create_embedding", return_value=mock_response
)
def error_factory(error_instance, error_count, retry_count, warn_user=True):
class RaisesError:
def __init__(self):
self.count = 0
@retry_openai_api(
num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
)
def __call__(self):
self.count += 1
if self.count <= error_count:
raise error_instance
return self.count
return RaisesError()
def test_retry_open_api_no_error(capsys):
@retry_openai_api()
def f():
return 1
result = f()
assert result == 1
output = capsys.readouterr()
assert output.out == ""
assert output.err == ""
@pytest.mark.parametrize(
"error_count, retry_count, failure",
[(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
)
def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
call_count = min(error_count, retry_count) + 1
raises = error_factory(error, error_count, retry_count)
if failure:
with pytest.raises(type(error)):
raises()
else:
result = raises()
assert result == call_count
assert raises.count == call_count
output = capsys.readouterr()
if error_count and retry_count:
if type(error) == RateLimitError:
assert "Reached rate limit, passing..." in output.out
assert "Please double check" in output.out
if type(error) == APIError:
assert "API Bad gateway" in output.out
else:
assert output.out == ""
def test_retry_open_api_rate_limit_no_warn(capsys):
error_count = 2
retry_count = 10
raises = error_factory(RateLimitError, error_count, retry_count, warn_user=False)
result = raises()
call_count = min(error_count, retry_count) + 1
assert result == call_count
assert raises.count == call_count
output = capsys.readouterr()
assert "Reached rate limit, passing..." in output.out
assert "Please double check" not in output.out
def test_retry_openapi_other_api_error(capsys):
error_count = 2
retry_count = 10
raises = error_factory(APIError("Error", http_status=500), error_count, retry_count)
with pytest.raises(APIError):
raises()
call_count = 1
assert raises.count == call_count
output = capsys.readouterr()
assert output.out == ""
def test_get_ada_embedding(mock_create_embedding, api_manager):
model = "text-embedding-ada-002"
embedding = get_ada_embedding("test")
mock_create_embedding.assert_called_once_with(
"test", model="text-embedding-ada-002"
)
assert embedding == [0.1, 0.2, 0.3]
cost = COSTS[model]["prompt"]
assert api_manager.get_total_prompt_tokens() == 5
assert api_manager.get_total_completion_tokens() == 0
assert api_manager.get_total_cost() == (5 * cost) / 1000

View File

@@ -21,7 +21,7 @@ def LocalCache():
@pytest.fixture
def mock_embed_with_ada(mocker):
mocker.patch(
"autogpt.memory.local.create_embedding_with_ada",
"autogpt.memory.local.get_ada_embedding",
return_value=[0.1] * EMBED_DIM,
)