Mirror of https://github.com/All-Hands-AI/OpenHands.git, synced 2026-01-08 22:38:05 -05:00
Add top_k (#8480)
@@ -28,6 +28,7 @@ class LLMConfig(BaseModel):
         max_message_chars: The approximate max number of characters in the content of an event included in the prompt to the LLM. Larger observations are truncated.
         temperature: The temperature for the API.
         top_p: The top p for the API.
+        top_k: The top k for the API.
         custom_llm_provider: The custom LLM provider to use. This is undocumented in openhands, and normally not used. It is documented on the litellm side.
         max_input_tokens: The maximum number of input tokens. Note that this is currently unused, and the value at runtime is actually the total tokens in OpenAI (e.g. 128,000 tokens for GPT-4).
         max_output_tokens: The maximum number of output tokens. This is sent to the LLM.
@@ -66,6 +67,7 @@ class LLMConfig(BaseModel):
     ) # maximum number of characters in an observation's content when sent to the llm
     temperature: float = Field(default=0.0)
     top_p: float = Field(default=1.0)
+    top_k: float | None = Field(default=None)
     custom_llm_provider: str | None = Field(default=None)
     max_input_tokens: int | None = Field(default=None)
     max_output_tokens: int | None = Field(default=None)
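
For orientation, here is a minimal sketch of constructing the config with the new field in Python. The import path and the top_k value of 40 are assumptions for illustration; the diff does not show file names, and only the defaults shown above come from the commit.

# Illustrative only -- not part of this commit.
# Assumed import path; the diff does not name the module defining LLMConfig.
from openhands.core.config import LLMConfig

config = LLMConfig(
    temperature=0.0,  # default, per the diff above
    top_p=1.0,        # default, per the diff above
    top_k=40,         # new optional field; leaving it as None keeps it out of requests
)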
@@ -159,6 +159,11 @@ class LLM(RetryMixin, DebugMixin):
             'temperature': self.config.temperature,
             'max_completion_tokens': self.config.max_output_tokens,
         }
+        if self.config.top_k is not None:
+            # openai doesn't expose top_k
+            # litellm will handle it a bit differently than the openai-compatible params
+            kwargs['top_k'] = self.config.top_k
+
         if (
             self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS
             or self.config.model.split('/')[-1] in REASONING_EFFORT_SUPPORTED_MODELS
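
The guard above forwards top_k only when it is explicitly configured, so providers that reject unknown sampling parameters (the OpenAI API does not expose top_k, as the comment notes) never receive it. Below is a self-contained sketch of that conditional-kwarg pattern; the function and parameter names are illustrative, not taken from the diff.

# Illustrative sketch of the pattern used above, outside the OpenHands classes.
from typing import Any


def build_completion_kwargs(
    temperature: float,
    max_output_tokens: int | None,
    top_k: float | None = None,
) -> dict[str, Any]:
    kwargs: dict[str, Any] = {
        'temperature': temperature,
        'max_completion_tokens': max_output_tokens,
    }
    if top_k is not None:
        # Include top_k only when set; litellm translates it for providers
        # that support it.
        kwargs['top_k'] = top_k
    return kwargs


assert 'top_k' not in build_completion_kwargs(0.0, 1024)
assert build_completion_kwargs(0.0, 1024, top_k=50)['top_k'] == 50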
@@ -152,6 +152,7 @@ def test_llm_init_with_custom_config():
         max_output_tokens=1500,
         temperature=0.8,
         top_p=0.9,
+        top_k=None,
     )
     llm = LLM(custom_config)
     assert llm.config.model == 'custom-model'
@@ -160,6 +161,42 @@ def test_llm_init_with_custom_config():
     assert llm.config.max_output_tokens == 1500
     assert llm.config.temperature == 0.8
     assert llm.config.top_p == 0.9
+    assert llm.config.top_k is None
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
+    # Create a config with top_k set
+    config_with_top_k = LLMConfig(top_k=50)
+    llm = LLM(config_with_top_k)
+
+    # Define a side effect function to check top_k
+    def side_effect(*args, **kwargs):
+        assert 'top_k' in kwargs
+        assert kwargs['top_k'] == 50
+        return {'choices': [{'message': {'content': 'Mocked response'}}]}
+
+    mock_litellm_completion.side_effect = side_effect
+
+    # Call completion
+    llm.completion(messages=[{'role': 'system', 'content': 'Test message'}])
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
+    # Create a config with top_k set to None
+    config_without_top_k = LLMConfig(top_k=None)
+    llm = LLM(config_without_top_k)
+
+    # Define a side effect function to check top_k
+    def side_effect(*args, **kwargs):
+        assert 'top_k' not in kwargs
+        return {'choices': [{'message': {'content': 'Mocked response'}}]}
+
+    mock_litellm_completion.side_effect = side_effect
+
+    # Call completion
+    llm.completion(messages=[{'role': 'system', 'content': 'Test message'}])
 
 
 def test_llm_init_with_metrics():
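
To run just the new tests locally, pytest's name filter is enough; the snippet below is a convenience sketch, not part of this commit, and assumes no particular test file path.

# Hypothetical runner: select the new tests by name.
# Equivalent to running: pytest -k top_k -v
import pytest

pytest.main(['-k', 'top_k', '-v'])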