diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index fab3de1e40..0569acbf52 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -502,11 +502,13 @@ class LLM(RetryMixin, DebugMixin):
         # Set max_output_tokens from model info if not explicitly set
         if self.config.max_output_tokens is None:
-            # Special case for Claude 3.7 Sonnet models
-            if any(
-                model in self.config.model
-                for model in ['claude-3-7-sonnet', 'claude-3.7-sonnet']
-            ):
+            # Special case for Claude Sonnet models
+            sonnet_models = [
+                'claude-3-7-sonnet',
+                'claude-3.7-sonnet',
+                'claude-sonnet-4',
+            ]
+            if any(model in self.config.model for model in sonnet_models):
                 self.config.max_output_tokens = 64000  # litellm set max to 128k, but that requires a header to be set
             # Try to get from model info
             elif self.model_info is not None:
diff --git a/tests/unit/llm/test_llm.py b/tests/unit/llm/test_llm.py
index 4a7de5a0f7..e95baedbc0 100644
--- a/tests/unit/llm/test_llm.py
+++ b/tests/unit/llm/test_llm.py
@@ -1053,17 +1053,22 @@ def test_claude_3_7_sonnet_max_output_tokens():
     assert llm.config.max_input_tokens is None
 
 
-def test_claude_sonnet_4_max_output_tokens():
+@patch('openhands.llm.llm.litellm.get_model_info')
+def test_claude_sonnet_4_max_output_tokens(mock_get_model_info):
     """Test that Claude Sonnet 4 models get the correct max_output_tokens and max_input_tokens values."""
+    mock_get_model_info.return_value = {
+        'max_input_tokens': 100000,
+        'max_output_tokens': 100000,
+    }
     # Create LLM instance with a Claude Sonnet 4 model
     config = LLMConfig(model='claude-sonnet-4-20250514', api_key='test_key')
     llm = LLM(config, service_id='test-service')
+    llm.init_model_info()
 
-    # Verify max_output_tokens is set to the expected value
-    assert llm.config.max_output_tokens == 64000
-    # Verify max_input_tokens is set to the expected value
-    # For Claude models, we expect a specific value from litellm
-    assert llm.config.max_input_tokens == 200000
+    assert llm.config.max_output_tokens == 64000, 'output max should be decreased'
+    assert llm.config.max_input_tokens == 100000, (
+        'input max should be the litellm value'
+    )
 
 
 def test_sambanova_deepseek_model_max_output_tokens():