diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 150fa54925..c786cfd6a3 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -21,6 +21,7 @@ from litellm import completion as litellm_completion from litellm import completion_cost as litellm_completion_cost from litellm.exceptions import ( APIConnectionError, + BadGatewayError, RateLimitError, ServiceUnavailableError, ) @@ -45,6 +46,7 @@ LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = ( APIConnectionError, RateLimitError, ServiceUnavailableError, + BadGatewayError, litellm.Timeout, litellm.InternalServerError, LLMNoResponseError, diff --git a/tests/unit/llm/test_api_connection_error_retry.py b/tests/unit/llm/test_api_connection_error_retry.py index 8bcf15f986..b88c170079 100644 --- a/tests/unit/llm/test_api_connection_error_retry.py +++ b/tests/unit/llm/test_api_connection_error_retry.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from litellm.exceptions import APIConnectionError +from litellm.exceptions import APIConnectionError, BadGatewayError from openhands.core.config import LLMConfig from openhands.llm.llm import LLM @@ -86,3 +86,25 @@ def test_completion_max_retries_api_connection_error( # The exception doesn't contain retry information in the current implementation # Just verify that we got an APIConnectionError assert 'API connection error' in str(excinfo.value) + + +@patch('openhands.llm.llm.litellm_completion') +def test_completion_retries_bad_gateway_error(mock_litellm_completion, default_config): + """Test that BadGatewayError is properly retried.""" + mock_litellm_completion.side_effect = [ + BadGatewayError( + message='Bad gateway', + llm_provider='test_provider', + model='test_model', + ), + {'choices': [{'message': {'content': 'Retry successful'}}]}, + ] + + llm = LLM(config=default_config, service_id='test-service') + response = llm.completion( + messages=[{'role': 'user', 'content': 'Hello!'}], + stream=False, + ) + + assert response['choices'][0]['message']['content'] == 'Retry successful' + assert mock_litellm_completion.call_count == 2