Mirror of https://github.com/All-Hands-AI/OpenHands.git (synced 2026-01-10 07:18:10 -05:00)
[Refactor]: Add LLMRegistry for llm services (#9589)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
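For orientation, here is a minimal, non-authoritative sketch of the pattern this refactor introduces, assembled only from calls that appear in the test changes below; the model name and service ids are illustrative, and exact signatures may differ from the real implementation.

```python
from openhands.core.config import OpenHandsConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry

config = OpenHandsConfig()

# A registry now owns the LLM instances used by a runtime/conversation.
llm_registry = LLMRegistry(config=config)

# LLMs are keyed by a service id, either through the registry...
agent_llm = llm_registry.get_llm('agent_llm', LLMConfig(model='gpt-4o', api_key='test_key'))

# ...or directly, since the LLM constructor now takes a service_id argument.
standalone_llm = LLM(config.get_llm_config(), service_id='test-service')
```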
tests/__init__.py (normal file, 0 lines changed)
@@ -10,6 +10,7 @@ from pytest import TempPathFactory
from openhands.core.config import MCPConfig, OpenHandsConfig, load_openhands_config
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime

@@ -268,9 +269,13 @@ def _load_runtime(
)
event_stream = EventStream(sid, file_store)
# Create a LLMRegistry instance for the runtime
llm_registry = LLMRegistry(config=OpenHandsConfig())
runtime = runtime_cls(
config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid=sid,
plugins=plugins,
)
tests/unit/__init__.py (normal file, 0 lines changed)
@@ -20,7 +20,7 @@ def test_llm():
def _get_llm(type_: type[LLM]):
with _patch_http():
return type_(config=config.get_llm_config())
return type_(config=config.get_llm_config(), service_id='test_service')
@pytest.fixture

@@ -38,7 +38,7 @@ def default_config():
def test_llm_init_with_default_config(default_config):
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
assert llm.config.model == 'gpt-4o'
assert llm.config.api_key.get_secret_value() == 'test_key'
assert isinstance(llm.metrics, Metrics)

@@ -129,7 +129,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
'max_input_tokens': 8000,
'max_output_tokens': 2000,
}
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.init_model_info()
assert llm.config.max_input_tokens == 8000
assert llm.config.max_output_tokens == 2000

@@ -138,7 +138,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_without_model_info(mock_get_model_info, default_config):
mock_get_model_info.side_effect = Exception('Model info not available')
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.init_model_info()
assert llm.config.max_input_tokens is None
assert llm.config.max_output_tokens is None

@@ -154,7 +154,7 @@ def test_llm_init_with_custom_config():
top_p=0.9,
top_k=None,
)
llm = LLM(custom_config)
llm = LLM(custom_config, service_id='test-service')
assert llm.config.model == 'custom-model'
assert llm.config.api_key.get_secret_value() == 'custom_key'
assert llm.config.max_input_tokens == 5000

@@ -168,7 +168,7 @@ def test_llm_init_with_custom_config():
def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
# Create a config with top_k set
config_with_top_k = LLMConfig(top_k=50)
llm = LLM(config_with_top_k)
llm = LLM(config_with_top_k, service_id='test-service')
# Define a side effect function to check top_k
def side_effect(*args, **kwargs):

@@ -186,7 +186,7 @@ def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
# Create a config with top_k set to None
config_without_top_k = LLMConfig(top_k=None)
llm = LLM(config_without_top_k)
llm = LLM(config_without_top_k, service_id='test-service')
# Define a side effect function to check top_k
def side_effect(*args, **kwargs):

@@ -202,7 +202,7 @@ def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
def test_llm_init_with_metrics():
config = LLMConfig(model='gpt-4o', api_key='test_key')
metrics = Metrics()
llm = LLM(config, metrics=metrics)
llm = LLM(config, metrics=metrics, service_id='test-service')
assert llm.metrics is metrics
assert (
llm.metrics.model_name == 'default'

@@ -224,7 +224,7 @@ def test_response_latency_tracking(mock_time, mock_litellm_completion):
# Create LLM instance and make a completion call
config = LLMConfig(model='gpt-4o', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
response = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
# Verify the response latency was tracked correctly

@@ -257,7 +257,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
'max_input_tokens': 7000,
'max_output_tokens': 1500,
}
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.init_model_info()
assert llm.config.max_input_tokens == 7000
assert llm.config.max_output_tokens == 1500

@@ -280,7 +280,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):
default_config.model = (
'custom-model' # Use a model not in FUNCTION_CALLING_SUPPORTED_MODELS
)
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
tools=[

@@ -292,7 +292,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):
# Test with Grok-4 model that doesn't support stop parameter
default_config.model = 'xai/grok-4-0709'
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
tools=[

@@ -314,7 +314,7 @@ def test_completion_with_mocked_logger(
'choices': [{'message': {'content': 'Test response'}}]
}
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
@@ -345,7 +345,7 @@ def test_completion_retries(
{'choices': [{'message': {'content': 'Retry successful'}}]},
]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,

@@ -365,7 +365,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
{'choices': [{'message': {'content': 'Retry successful'}}]},
]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,

@@ -387,7 +387,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
def test_completion_operation_cancelled(mock_litellm_completion, default_config):
mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
with pytest.raises(OperationCancelled):
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],

@@ -404,7 +404,7 @@ def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
mock_litellm_completion.side_effect = side_effect
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
with pytest.raises(OperationCancelled):
try:
llm.completion(

@@ -428,7 +428,7 @@ def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_
mock_litellm_completion.side_effect = side_effect
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
result = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,

@@ -469,7 +469,7 @@ def test_completion_retry_with_llm_no_response_error_zero_temp(
mock_litellm_completion.side_effect = side_effect
# Create LLM instance and make a completion call
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,

@@ -509,7 +509,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp(
'LLM did not return a response'
)
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
with pytest.raises(LLMNoResponseError):
llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],

@@ -575,7 +575,7 @@ def test_gemini_25_pro_function_calling(mock_httpx_get, mock_get_model_info):
for model_name, expected_support in test_cases:
config = LLMConfig(model=model_name, api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
assert llm.is_function_calling_active() == expected_support, (
f'Expected function calling support to be {expected_support} for model {model_name}'

@@ -617,7 +617,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp_successful_ret
mock_litellm_completion.side_effect = side_effect
# Create LLM instance and make a completion call with non-zero temperature
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,

@@ -677,7 +677,7 @@ def test_completion_retry_with_llm_no_response_error_successful_retry(
mock_litellm_completion.side_effect = side_effect
# Create LLM instance and make a completion call with explicit temperature=0
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,

@@ -709,7 +709,7 @@ def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
}
mock_litellm_completion.return_value = mock_response
test_llm = LLM(config=default_config)
test_llm = LLM(config=default_config, service_id='test-service')
response = test_llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,

@@ -743,7 +743,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
}
# Initialize LLM and call completion
llm = LLM(config=gemini_config)
llm = LLM(config=gemini_config, service_id='test-service')
llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
# Verify that litellm_completion was called with the 'thinking' parameter

@@ -762,7 +762,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
@patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
mock_token_counter.return_value = 42
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages)

@@ -777,7 +777,7 @@ def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
def test_get_token_count_with_message_objects(
mock_token_counter, default_config, mock_logger
):
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
# Create a Message object and its equivalent dict
message_obj = Message(role='user', content=[TextContent(text='Hello!')])

@@ -806,7 +806,7 @@ def test_get_token_count_with_custom_tokenizer(
config = copy.deepcopy(default_config)
config.custom_tokenizer = 'custom/tokenizer'
llm = LLM(config)
llm = LLM(config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages)

@@ -823,7 +823,7 @@ def test_get_token_count_error_handling(
mock_token_counter, default_config, mock_logger
):
mock_token_counter.side_effect = Exception('Token counting failed')
llm = LLM(default_config)
llm = LLM(default_config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages)
@@ -865,7 +865,7 @@ def test_llm_token_usage(mock_litellm_completion, default_config):
# We'll make mock_litellm_completion return these responses in sequence
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
# First call
llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}])

@@ -924,7 +924,7 @@ def test_accumulated_token_usage(mock_litellm_completion, default_config):
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
# Create LLM instance
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
# First call
llm.completion(messages=[{'role': 'user', 'content': 'First message'}])

@@ -980,7 +980,7 @@ def test_completion_with_log_completions(mock_litellm_completion, default_config
}
mock_litellm_completion.return_value = mock_response
test_llm = LLM(config=default_config)
test_llm = LLM(config=default_config, service_id='test-service')
response = test_llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,

@@ -1006,7 +1006,7 @@ def test_llm_base_url_auto_protocol_patch(mock_get):
mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = {'model': 'fake'}
llm = LLM(config=config)
llm = LLM(config=config, service_id='test-service')
llm.init_model_info()
called_url = mock_get.call_args[0][0]

@@ -1020,7 +1020,7 @@ def test_unknown_model_token_limits():
"""Test that models without known token limits get None for both max_output_tokens and max_input_tokens."""
# Create LLM instance with a non-existent model to avoid litellm having model info for it
config = LLMConfig(model='non-existent-model', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# Verify max_output_tokens and max_input_tokens are initialized to None (default value)
assert llm.config.max_output_tokens is None

@@ -1031,7 +1031,7 @@ def test_max_tokens_from_model_info():
"""Test that max_output_tokens and max_input_tokens are correctly initialized from model info."""
# Create LLM instance with GPT-4 model which has known token limits
config = LLMConfig(model='gpt-4', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# GPT-4 has specific token limits
# These are the expected values from litellm

@@ -1043,7 +1043,7 @@ def test_claude_3_7_sonnet_max_output_tokens():
"""Test that Claude 3.7 Sonnet models get the special 64000 max_output_tokens value and default max_input_tokens."""
# Create LLM instance with Claude 3.7 Sonnet model
config = LLMConfig(model='claude-3-7-sonnet', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# Verify max_output_tokens is set to 64000 for Claude 3.7 Sonnet
assert llm.config.max_output_tokens == 64000

@@ -1055,7 +1055,7 @@ def test_claude_sonnet_4_max_output_tokens():
"""Test that Claude Sonnet 4 models get the correct max_output_tokens and max_input_tokens values."""
# Create LLM instance with a Claude Sonnet 4 model
config = LLMConfig(model='claude-sonnet-4-20250514', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# Verify max_output_tokens is set to the expected value
assert llm.config.max_output_tokens == 64000

@@ -1068,7 +1068,7 @@ def test_sambanova_deepseek_model_max_output_tokens():
"""Test that SambaNova DeepSeek-V3-0324 model gets the correct max_output_tokens value."""
# Create LLM instance with SambaNova DeepSeek model
config = LLMConfig(model='sambanova/DeepSeek-V3-0324', api_key='test_key')
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# SambaNova DeepSeek model has specific token limits
# This is the expected value from litellm

@@ -1081,7 +1081,7 @@ def test_max_output_tokens_override_in_config():
config = LLMConfig(
model='claude-sonnet-4-20250514', api_key='test_key', max_output_tokens=2048
)
llm = LLM(config)
llm = LLM(config, service_id='test-service')
# Verify the config has the overridden max_output_tokens value
assert llm.config.max_output_tokens == 2048

@@ -1098,7 +1098,7 @@ def test_azure_model_default_max_tokens():
)
# Create LLM instance with Azure model
llm = LLM(azure_config)
llm = LLM(azure_config, service_id='test-service')
# Verify the config has the default max_output_tokens value
assert llm.config.max_output_tokens is None # Default value

@@ -1143,7 +1143,7 @@ def test_gemini_none_reasoning_effort_uses_thinking_budget(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)

@@ -1167,7 +1167,7 @@ def test_gemini_low_reasoning_effort_uses_thinking_budget(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)

@@ -1191,7 +1191,7 @@ def test_gemini_medium_reasoning_effort_passes_through(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)

@@ -1214,7 +1214,7 @@ def test_gemini_high_reasoning_effort_passes_through(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)

@@ -1235,7 +1235,7 @@ def test_non_gemini_uses_reasoning_effort(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)

@@ -1259,7 +1259,7 @@ def test_non_reasoning_model_no_optimization(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
}
llm = LLM(config)
llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages)

@@ -1285,7 +1285,7 @@ def test_gemini_performance_optimization_end_to_end(mock_completion):
assert config.reasoning_effort is None
# Create LLM and make completion
llm = LLM(config)
llm = LLM(config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Solve this complex problem'}]
response = llm.completion(messages=messages)
@@ -207,7 +207,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
),
]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextIssue(
GithubIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
)

@@ -251,7 +251,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
)
# Initialize LLM and handler
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextPR(
GithubPRHandler('test-owner', 'test-repo', 'test-token'), default_config
)

@@ -463,7 +463,7 @@ async def test_process_issue(
[],
)
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
handler_instance.llm = LLM(llm_config)
handler_instance.llm = LLM(llm_config, service_id='test-service')
# Mock the runtime and its methods
mock_runtime = MagicMock()

@@ -209,7 +209,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
),
]
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextIssue(
GitlabIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
)

@@ -253,7 +253,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
)
# Initialize LLM and handler
llm = LLM(config=default_config)
llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextPR(
GitlabPRHandler('test-owner', 'test-repo', 'test-token'), default_config
)

@@ -500,7 +500,7 @@ async def test_process_issue(
[],
)
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
handler_instance.llm = LLM(llm_config)
handler_instance.llm = LLM(llm_config, service_id='test-service')
# Create mock runtime and mock run_controller
mock_runtime = MagicMock()
@@ -18,6 +18,7 @@ from openhands.controller.state.control_flags import (
from openhands.controller.state.state import State
from openhands.core.config import OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber

@@ -33,6 +34,7 @@ from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
from openhands.llm.metrics import Metrics, TokenUsage
from openhands.memory.condenser.condenser import Condensation
from openhands.memory.condenser.impl.conversation_window_condenser import (

@@ -45,6 +47,7 @@ from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore
@@ -61,15 +64,43 @@ def event_loop():
@pytest.fixture
def mock_agent():
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = OpenHandsConfig().get_llm_config()
def mock_agent_with_stats():
"""Create a mock agent with properly connected LLM registry and conversation stats."""
import uuid
# Add config with enable_mcp attribute
agent.config = MagicMock(spec=AgentConfig)
agent.config.enable_mcp = True
# Create LLM registry
config = OpenHandsConfig()
llm_registry = LLMRegistry(config=config)
# Create conversation stats
file_store = InMemoryFileStore({})
conversation_id = f'test-conversation-{uuid.uuid4()}'
conversation_stats = ConversationStats(
file_store=file_store, conversation_id=conversation_id, user_id='test-user'
)
# Connect registry to stats (this is the key requirement)
llm_registry.subscribe(conversation_stats.register_llm)
# Create mock agent
agent = MagicMock(spec=Agent)
agent_config = MagicMock(spec=AgentConfig)
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
agent_config.disabled_microagents = []
agent_config.enable_mcp = True
llm_registry.service_to_llm.clear()
mock_llm = llm_registry.get_llm('agent_llm', llm_config)
agent.llm = mock_llm
agent.name = 'test-agent'
agent.sandbox_plugins = []
agent.config = agent_config
agent.prompt_manager = MagicMock()
# Add a proper system message mock
system_message = SystemMessageAction(

@@ -79,7 +110,7 @@ def mock_agent():
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
return agent
return agent, conversation_stats, llm_registry
@pytest.fixture
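The rewritten fixture above also shows how per-conversation usage tracking is wired to the registry: the registry broadcasts each LLM it creates (or is explicitly told about) to subscribers such as ConversationStats. A condensed, non-authoritative sketch of that wiring, using only calls that appear in this diff (identifiers and values are illustrative):

```python
import uuid

from openhands.core.config import OpenHandsConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

llm_registry = LLMRegistry(config=OpenHandsConfig())

conversation_stats = ConversationStats(
    file_store=InMemoryFileStore({}),
    conversation_id=f'test-conversation-{uuid.uuid4()}',
    user_id='test-user',
)

# Each LLM the registry hands out is reported to the stats object.
llm_registry.subscribe(conversation_stats.register_llm)
agent_llm = llm_registry.get_llm('agent_llm', LLMConfig(model='gpt-4o', api_key='test_key'))

# Pre-built LLMs can also be registered and announced explicitly,
# as the condenser tests later in this diff do.
llm_registry.service_to_llm['condenser'] = agent_llm
llm_registry.notify(RegistryEvent(llm=agent_llm, service_id='condenser'))
```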
@@ -134,10 +165,13 @@ async def send_event_to_controller(controller, event):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_agent_state(mock_agent, mock_event_stream):
|
||||
async def test_set_agent_state(mock_agent_with_stats, mock_event_stream):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -152,10 +186,13 @@ async def test_set_agent_state(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_message_action(mock_agent, mock_event_stream):
|
||||
async def test_on_event_message_action(mock_agent_with_stats, mock_event_stream):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -169,10 +206,15 @@ async def test_on_event_message_action(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
|
||||
async def test_on_event_change_agent_state_action(
|
||||
mock_agent_with_stats, mock_event_stream
|
||||
):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -186,10 +228,17 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
|
||||
async def test_react_to_exception(
|
||||
mock_agent_with_stats,
|
||||
mock_event_stream,
|
||||
mock_status_callback,
|
||||
):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
status_callback=mock_status_callback,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
@@ -204,12 +253,17 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_to_content_policy_violation(
|
||||
mock_agent, mock_event_stream, mock_status_callback
|
||||
mock_agent_with_stats,
|
||||
mock_event_stream,
|
||||
mock_status_callback,
|
||||
):
|
||||
"""Test that the controller properly handles content policy violations from the LLM."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
status_callback=mock_status_callback,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
@@ -246,18 +300,16 @@ async def test_react_to_content_policy_violation(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_fatal_error(
|
||||
test_event_stream, mock_memory, mock_agent
|
||||
test_event_stream, mock_memory, mock_agent_with_stats
|
||||
):
|
||||
config = OpenHandsConfig()
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
mock_agent.step = agent_step_fn
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
@@ -284,15 +336,17 @@ async def test_run_controller_with_fatal_error(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
print(f'state: {state}')
|
||||
events = list(test_event_stream.get_events())
|
||||
print(f'event_stream: {events}')
|
||||
@@ -312,18 +366,16 @@ async def test_run_controller_with_fatal_error(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_stop_with_stuck(
|
||||
test_event_stream, mock_memory, mock_agent
|
||||
test_event_stream, mock_memory, mock_agent_with_stats
|
||||
):
|
||||
config = OpenHandsConfig()
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
mock_agent.step = agent_step_fn
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
@@ -352,15 +404,17 @@ async def test_run_controller_stop_with_stuck(
|
||||
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
|
||||
)
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
events = list(test_event_stream.get_events())
|
||||
print(f'state: {state}')
|
||||
for i, event in enumerate(events):
|
||||
@@ -391,11 +445,14 @@ async def test_run_controller_stop_with_stuck(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
async def test_max_iterations_extension(mock_agent_with_stats, mock_event_stream):
|
||||
# Test with headless_mode=False - should extend max_iterations
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -426,6 +483,7 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -450,7 +508,9 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_budget(mock_agent, mock_event_stream):
|
||||
async def test_step_max_budget(mock_agent_with_stats, mock_event_stream):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Metrics are always synced with budget flag before
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 10.1
|
||||
@@ -458,9 +518,13 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
|
||||
limit_increase_amount=10, current_value=10.1, max_value=10
|
||||
)
|
||||
|
||||
# Update agent's LLM metrics in place
|
||||
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
budget_per_task_delta=10,
|
||||
sid='test',
|
||||
@@ -475,7 +539,9 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
||||
async def test_step_max_budget_headless(mock_agent_with_stats, mock_event_stream):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Metrics are always synced with budget flag before
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 10.1
|
||||
@@ -483,9 +549,13 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
||||
limit_increase_amount=10, current_value=10.1, max_value=10
|
||||
)
|
||||
|
||||
# Update agent's LLM metrics in place
|
||||
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
budget_per_task_delta=10,
|
||||
sid='test',
|
||||
@@ -500,12 +570,14 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
|
||||
async def test_budget_reset_on_continue(mock_agent_with_stats, mock_event_stream):
|
||||
"""Test that when a user continues after hitting the budget limit:
|
||||
1. Error is thrown when budget cap is exceeded
|
||||
2. LLM budget does not reset when user continues
|
||||
3. Budget is extended by adding the initial budget cap to the current accumulated cost
|
||||
"""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Create a real Metrics instance shared between controller state and llm
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 6.0
|
||||
@@ -521,10 +593,14 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
|
||||
),
|
||||
)
|
||||
|
||||
# Update agent's LLM metrics in place
|
||||
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
|
||||
|
||||
# Create controller with budget cap
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
budget_per_task_delta=initial_budget,
|
||||
sid='test',
|
||||
@@ -570,11 +646,17 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
|
||||
async def test_reset_with_pending_action_no_observation(
|
||||
mock_agent_with_stats, mock_event_stream
|
||||
):
|
||||
"""Test reset() when there's a pending action with tool call metadata but no observation."""
|
||||
# Connect LLM registry to conversation stats
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -617,11 +699,17 @@ async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_s
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_stream):
|
||||
async def test_reset_with_pending_action_stopped_state(
|
||||
mock_agent_with_stats, mock_event_stream
|
||||
):
|
||||
"""Test reset() when there's a pending action and agent state is STOPPED."""
|
||||
# Connect LLM registry to conversation stats
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -665,12 +753,16 @@ async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_st
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_pending_action_existing_observation(
|
||||
mock_agent, mock_event_stream
|
||||
mock_agent_with_stats, mock_event_stream
|
||||
):
|
||||
"""Test reset() when there's a pending action with tool call metadata and an existing observation."""
|
||||
# Connect LLM registry to conversation stats
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -708,11 +800,15 @@ async def test_reset_with_pending_action_existing_observation(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_without_pending_action(mock_agent, mock_event_stream):
|
||||
async def test_reset_without_pending_action(mock_agent_with_stats, mock_event_stream):
|
||||
"""Test reset() when there's no pending action."""
|
||||
# Connect LLM registry to conversation stats
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -738,12 +834,15 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_pending_action_no_metadata(
|
||||
mock_agent, mock_event_stream, monkeypatch
|
||||
mock_agent_with_stats, mock_event_stream, monkeypatch
|
||||
):
|
||||
"""Test reset() when there's a pending action without tool call metadata."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -782,16 +881,13 @@ async def test_reset_with_pending_action_no_metadata(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_max_iterations_has_metrics(
|
||||
test_event_stream, mock_memory, mock_agent
|
||||
test_event_stream, mock_memory, mock_agent_with_stats
|
||||
):
|
||||
config = OpenHandsConfig(
|
||||
max_iterations=3,
|
||||
)
|
||||
event_stream = test_event_stream
|
||||
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
step_count = 0
|
||||
|
||||
@@ -833,15 +929,17 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()))
|
||||
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
|
||||
state.metrics = mock_agent.llm.metrics
|
||||
assert state.iteration_flag.current_value == 3
|
||||
@@ -867,10 +965,17 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback):
|
||||
async def test_notify_on_llm_retry(
|
||||
mock_agent_with_stats,
|
||||
mock_event_stream,
|
||||
mock_status_callback,
|
||||
):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
status_callback=mock_status_callback,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
@@ -908,9 +1013,15 @@ async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_ca
|
||||
],
|
||||
)
|
||||
async def test_context_window_exceeded_error_handling(
|
||||
context_window_error, mock_agent, mock_runtime, test_event_stream, mock_memory
|
||||
context_window_error,
|
||||
mock_agent_with_stats,
|
||||
mock_runtime,
|
||||
test_event_stream,
|
||||
mock_memory,
|
||||
):
|
||||
"""Test that context window exceeded errors are handled correctly by the controller, providing a smaller view but keeping the history intact."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
max_iterations = 5
|
||||
error_after = 2
|
||||
|
||||
@@ -973,18 +1084,20 @@ async def test_context_window_exceeded_error_handling(
|
||||
# state is set to error out before then, if this terminates and we have a
|
||||
# record of the error being thrown we can be confident that the controller
|
||||
# handles the truncation correctly.
|
||||
final_state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
final_state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Check that the context window exception was thrown and the controller
|
||||
# called the agent's `step` function the right number of times.
|
||||
@@ -1072,9 +1185,13 @@ async def test_context_window_exceeded_error_handling(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
mock_agent, mock_runtime, mock_memory, test_event_stream
|
||||
mock_agent_with_stats,
|
||||
mock_runtime,
|
||||
mock_memory,
|
||||
test_event_stream,
|
||||
):
|
||||
"""Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
class StepState:
|
||||
def __init__(self):
|
||||
@@ -1121,18 +1238,20 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
mock_runtime.config = copy.deepcopy(config)
|
||||
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# A timeout error indicates the run_controller entrypoint is not making
|
||||
# progress
|
||||
@@ -1156,9 +1275,13 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
mock_agent, mock_runtime, mock_memory, test_event_stream
|
||||
mock_agent_with_stats,
|
||||
mock_runtime,
|
||||
mock_memory,
|
||||
test_event_stream,
|
||||
):
|
||||
"""Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
class StepState:
|
||||
def __init__(self):
|
||||
@@ -1199,18 +1322,20 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
config = OpenHandsConfig(max_iterations=3)
|
||||
mock_runtime.config = copy.deepcopy(config)
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
llm_registry=llm_registry,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# A timeout error indicates the run_controller entrypoint is not making
|
||||
# progress
|
||||
@@ -1244,7 +1369,11 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
async def test_run_controller_with_memory_error(
|
||||
test_event_stream, mock_agent_with_stats
|
||||
):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
config = OpenHandsConfig()
|
||||
event_stream = test_event_stream
|
||||
|
||||
@@ -1273,15 +1402,17 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
with patch.object(
|
||||
memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge
|
||||
):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
)
|
||||
# Mock the create_agent function to return our mock agent
|
||||
with patch('openhands.core.main.create_agent', return_value=mock_agent):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content='Test message'),
|
||||
runtime=runtime,
|
||||
sid='test',
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=memory,
|
||||
llm_registry=llm_registry,
|
||||
)
|
||||
|
||||
assert state.iteration_flag.current_value == 0
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
@@ -1289,7 +1420,9 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_metrics_copy(mock_agent):
|
||||
async def test_action_metrics_copy(mock_agent_with_stats):
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
@@ -1299,8 +1432,7 @@ async def test_action_metrics_copy(mock_agent):
|
||||
|
||||
initial_state = State(metrics=metrics, budget_flag=None)
|
||||
|
||||
# Create agent with metrics
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
# Update agent's LLM metrics
|
||||
|
||||
# Add multiple token usages - we should get the last one in the action
|
||||
usage1 = TokenUsage(
|
||||
@@ -1342,6 +1474,11 @@ async def test_action_metrics_copy(mock_agent):
|
||||
|
||||
mock_agent.llm.metrics = metrics
|
||||
|
||||
# Register the metrics with the LLM registry
|
||||
llm_registry.service_to_llm['agent'] = mock_agent.llm
|
||||
# Manually notify the conversation stats about the LLM registration
|
||||
llm_registry.notify(RegistryEvent(llm=mock_agent.llm, service_id='agent'))
|
||||
|
||||
# Mock agent step to return an action
|
||||
action = MessageAction(content='Test message')
|
||||
|
||||
@@ -1354,6 +1491,7 @@ async def test_action_metrics_copy(mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -1411,12 +1549,13 @@ async def test_action_metrics_copy(mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
async def test_condenser_metrics_included(mock_agent_with_stats, test_event_stream):
|
||||
"""Test that metrics from the condenser's LLM are included in the action metrics."""
|
||||
# Set up agent metrics
|
||||
agent_metrics = Metrics(model_name='agent-model')
|
||||
agent_metrics.accumulated_cost = 0.05
|
||||
agent_metrics._accumulated_token_usage = TokenUsage(
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Set up agent metrics in place
|
||||
mock_agent.llm.metrics.accumulated_cost = 0.05
|
||||
mock_agent.llm.metrics._accumulated_token_usage = TokenUsage(
|
||||
model='agent-model',
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
@@ -1424,7 +1563,6 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
cache_write_tokens=10,
|
||||
response_id='agent-accumulated',
|
||||
)
|
||||
# mock_agent.llm.metrics = agent_metrics
|
||||
mock_agent.name = 'TestAgent'
|
||||
|
||||
# Create condenser with its own metrics
|
||||
@@ -1442,6 +1580,11 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
)
|
||||
condenser.llm.metrics = condenser_metrics
|
||||
|
||||
# Register the condenser metrics with the LLM registry
|
||||
llm_registry.service_to_llm['condenser'] = condenser.llm
|
||||
# Manually notify the conversation stats about the condenser LLM registration
|
||||
llm_registry.notify(RegistryEvent(llm=condenser.llm, service_id='condenser'))
|
||||
|
||||
# Attach the condenser to the mock_agent
|
||||
mock_agent.condenser = condenser
|
||||
|
||||
@@ -1463,11 +1606,12 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=test_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
initial_state=State(metrics=agent_metrics, budget_flag=None),
|
||||
initial_state=State(metrics=mock_agent.llm.metrics, budget_flag=None),
|
||||
)
|
||||
|
||||
# Execute one step
|
||||
@@ -1505,7 +1649,9 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_user_message_with_identical_content(test_event_stream, mock_agent):
|
||||
async def test_first_user_message_with_identical_content(
|
||||
test_event_stream, mock_agent_with_stats
|
||||
):
|
||||
"""Test that _first_user_message correctly identifies the first user message.
|
||||
|
||||
This test verifies that messages with identical content but different IDs are properly
|
||||
@@ -1514,14 +1660,12 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
|
||||
The issue we're checking is that the comparison (action == self._first_user_message())
|
||||
should correctly differentiate between messages with the same content but different IDs.
|
||||
"""
|
||||
# Create an agent controller
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = OpenHandsConfig().get_llm_config()
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=test_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -1569,11 +1713,15 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_controller_processes_null_observation_with_cause():
|
||||
async def test_agent_controller_processes_null_observation_with_cause(
|
||||
mock_agent_with_stats,
|
||||
):
|
||||
"""Test that AgentController processes NullObservation events with a cause value.
|
||||
|
||||
And that the agent's step method is called as a result.
|
||||
"""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Create an in-memory file store and real event stream
|
||||
file_store = InMemoryFileStore()
|
||||
event_stream = EventStream(sid='test-session', file_store=file_store)
|
||||
@@ -1581,19 +1729,11 @@ async def test_agent_controller_processes_null_observation_with_cause():
|
||||
# Create a Memory instance - not used directly in this test but needed for setup
|
||||
Memory(event_stream=event_stream, sid='test-session')
|
||||
|
||||
# Create a mock agent with necessary attributes
|
||||
mock_agent = MagicMock(spec=Agent)
|
||||
mock_agent.get_system_message = MagicMock(
|
||||
return_value=None,
|
||||
)
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = OpenHandsConfig().get_llm_config()
|
||||
|
||||
# Create a controller with the mock agent
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test-session',
|
||||
)
|
||||
@@ -1655,8 +1795,12 @@ async def test_agent_controller_processes_null_observation_with_cause():
|
||||
)
|
||||
|
||||
|
||||
def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agent):
|
||||
def test_agent_controller_should_step_with_null_observation_cause_zero(
|
||||
mock_agent_with_stats,
|
||||
):
|
||||
"""Test that AgentController's should_step method returns False for NullObservation with cause = 0."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
# Create a mock event stream
|
||||
file_store = InMemoryFileStore()
|
||||
event_stream = EventStream(sid='test-session', file_store=file_store)
|
||||
@@ -1665,6 +1809,7 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
sid='test-session',
|
||||
)
|
||||
@@ -1683,10 +1828,15 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
|
||||
)
|
||||
|
||||
|
||||
def test_system_message_in_event_stream(mock_agent, test_event_stream):
|
||||
def test_system_message_in_event_stream(mock_agent_with_stats, test_event_stream):
|
||||
"""Test that SystemMessageAction is added to event stream in AgentController."""
|
||||
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
|
||||
|
||||
_ = AgentController(
|
||||
agent=mock_agent, event_stream=test_event_stream, iteration_delta=10
|
||||
agent=mock_agent,
|
||||
event_stream=test_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=10,
|
||||
)
|
||||
|
||||
# Get events from the event stream
|
||||
|
||||
@@ -12,8 +12,9 @@ from openhands.controller.state.control_flags import (
|
||||
IterationControlFlag,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream
|
||||
from openhands.events.action import (
|
||||
@@ -28,11 +29,39 @@ from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage.memory import InMemoryFileStore


@pytest.fixture
def llm_registry():
    config = OpenHandsConfig()
    return LLMRegistry(config=config)


@pytest.fixture
def conversation_stats():
    import uuid

    file_store = InMemoryFileStore({})
    # Use a unique conversation ID for each test to avoid conflicts
    conversation_id = f'test-conversation-{uuid.uuid4()}'
    return ConversationStats(
        file_store=file_store, conversation_id=conversation_id, user_id='test-user'
    )


@pytest.fixture
def connected_registry_and_stats(llm_registry, conversation_stats):
    """Connect the LLMRegistry and ConversationStats properly"""
    # Subscribe to LLM registry events to track metrics
    llm_registry.subscribe(conversation_stats.register_llm)
    return llm_registry, conversation_stats

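# Illustrative sketch (not part of this diff): how the fixtures above fit together,
# assuming LLMRegistry.get_llm registers the LLM under the given service id and emits
# a registry event that ConversationStats.register_llm consumes. Ids are placeholders.
from openhands.core.config import OpenHandsConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

registry = LLMRegistry(config=OpenHandsConfig())
stats = ConversationStats(
    file_store=InMemoryFileStore({}),
    conversation_id='demo-conversation',
    user_id='demo-user',
)
registry.subscribe(stats.register_llm)

# Any LLM handed out by the registry is now visible to the stats service.
llm = registry.get_llm('agent_llm', LLMConfig(model='gpt-4o', api_key='test_key'))
llm.metrics.add_cost(0.05)
print(stats.get_combined_metrics().accumulated_cost)  # 0.05, under the assumption above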
@pytest.fixture
|
||||
def mock_event_stream():
|
||||
"""Creates an event stream in memory."""
|
||||
@@ -42,15 +71,17 @@ def mock_event_stream():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_parent_agent():
|
||||
def mock_parent_agent(llm_registry):
|
||||
"""Creates a mock parent agent for testing delegation."""
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.name = 'ParentAgent'
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.service_id = 'main_agent'
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = LLMConfig()
|
||||
agent.llm.retry_listener = None # Add retry_listener attribute
|
||||
agent.config = AgentConfig()
|
||||
agent.llm_registry = llm_registry # Add the missing llm_registry attribute
|
||||
|
||||
# Add a proper system message mock
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
@@ -61,15 +92,17 @@ def mock_parent_agent():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_child_agent():
|
||||
def mock_child_agent(llm_registry):
|
||||
"""Creates a mock child agent for testing delegation."""
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.name = 'ChildAgent'
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.service_id = 'main_agent'
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = LLMConfig()
|
||||
agent.llm.retry_listener = None # Add retry_listener attribute
|
||||
agent.config = AgentConfig()
|
||||
agent.llm_registry = llm_registry # Add the missing llm_registry attribute
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
@@ -78,15 +111,37 @@ def mock_child_agent():
|
||||
return agent
|
||||
|
||||
|
||||
def create_mock_agent_factory(mock_child_agent, llm_registry):
    """Helper function to create a mock agent factory with proper LLM registration."""

    def create_mock_agent(config, llm_registry=None):
        # Register the mock agent's LLM in the registry so get_combined_metrics() can find it
        if llm_registry:
            mock_child_agent.llm = llm_registry.get_llm('agent_llm', LLMConfig())
            mock_child_agent.llm_registry = (
                llm_registry  # Set the llm_registry attribute
            )
        return mock_child_agent

    return create_mock_agent
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream):
|
||||
"""Test that when the parent agent delegates to a child
|
||||
1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
|
||||
2. metrics are accumulated globally (delegate is adding to the parents metrics)
|
||||
3. local metrics for the delegate are still accessible
|
||||
async def test_delegation_flow(
|
||||
mock_parent_agent, mock_child_agent, mock_event_stream, connected_registry_and_stats
|
||||
):
|
||||
"""
|
||||
Test that when the parent agent delegates to a child
|
||||
1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
|
||||
2. metrics are accumulated globally via LLM registry (delegate adds to the global metrics)
|
||||
3. global metrics tracking works correctly through the LLM registry
|
||||
"""
|
||||
llm_registry, conversation_stats = connected_registry_and_stats
|
||||
|
||||
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent
|
||||
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
|
||||
Agent.get_cls = Mock(
|
||||
return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
|
||||
)
|
||||
|
||||
step_count = 0
|
||||
|
||||
@@ -97,6 +152,12 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s

    mock_child_agent.step = agent_step_fn

    # Set up the parent agent's LLM with initial cost and register it in the registry
    # The parent agent's LLM should use the existing registered LLM to ensure proper tracking
    parent_llm = llm_registry.service_to_llm['agent']
    parent_llm.metrics.accumulated_cost = 2
    mock_parent_agent.llm = parent_llm

    parent_metrics = Metrics()
    parent_metrics.accumulated_cost = 2
    # Create parent controller
||||
@@ -114,6 +175,7 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
parent_controller = AgentController(
|
||||
agent=mock_parent_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
sid='parent',
|
||||
confirmation_mode=False,
|
||||
@@ -180,21 +242,23 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
for i in range(4):
|
||||
delegate_controller.state.iteration_flag.step()
|
||||
delegate_controller.agent.step(delegate_controller.state)
|
||||
# Update the agent's LLM metrics (not the deprecated state metrics)
|
||||
delegate_controller.agent.llm.metrics.add_cost(1.0)
|
||||
|
||||
assert (
|
||||
delegate_controller.state.get_local_step() == 4
|
||||
) # verify local metrics are accessible via snapshot
|
||||
|
||||
# Check that the conversation stats has the combined metrics (parent + delegate)
|
||||
combined_metrics = delegate_controller.state.convo_stats.get_combined_metrics()
|
||||
assert (
|
||||
delegate_controller.state.metrics.accumulated_cost
|
||||
== 6 # Make sure delegate tracks global cost
|
||||
combined_metrics.accumulated_cost
|
||||
== 6 # Make sure delegate tracks global cost (2 from parent + 4 from delegate)
|
||||
)
|
||||
|
||||
assert (
|
||||
delegate_controller.state.get_local_metrics().accumulated_cost
|
||||
== 4 # Delegate spent one dollar per step
|
||||
)
|
||||
# Since metrics are now global via LLM registry, local metrics tracking
|
||||
# is handled differently. The delegate's LLM shares the same metrics object
|
||||
# as the parent for global tracking, so we verify the global total is correct.
|
||||
|
||||
delegate_controller.state.outputs = {'delegate_result': 'done'}
|
||||
|
||||
@@ -228,15 +292,18 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
],
|
||||
)
|
||||
async def test_delegate_step_different_states(
|
||||
mock_parent_agent, mock_event_stream, delegate_state
|
||||
mock_parent_agent, mock_event_stream, delegate_state, connected_registry_and_stats
|
||||
):
|
||||
"""Ensure that delegate is closed or remains open based on the delegate's state."""
|
||||
llm_registry, conversation_stats = connected_registry_and_stats
|
||||
|
||||
# Create a state with iteration_flag.max_value set to 10
|
||||
state = State(inputs={})
|
||||
state.iteration_flag.max_value = 10
|
||||
controller = AgentController(
|
||||
agent=mock_parent_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
@@ -292,11 +359,23 @@ async def test_delegate_step_different_states(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delegate_hits_global_limits(
|
||||
mock_child_agent, mock_event_stream, mock_parent_agent
|
||||
mock_child_agent, mock_event_stream, mock_parent_agent, connected_registry_and_stats
|
||||
):
|
||||
"""Global limits from control flags should apply to delegates"""
|
||||
"""
|
||||
Global limits from control flags should apply to delegates
|
||||
"""
|
||||
llm_registry, conversation_stats = connected_registry_and_stats
|
||||
|
||||
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent
|
||||
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
|
||||
Agent.get_cls = Mock(
|
||||
return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
|
||||
)
|
||||
|
||||
# Set up the parent agent's LLM with initial cost and register it in the registry
|
||||
mock_parent_agent.llm.metrics.accumulated_cost = 2
|
||||
mock_parent_agent.llm.service_id = 'main_agent'
|
||||
# Register the parent agent's LLM in the registry
|
||||
llm_registry.service_to_llm['main_agent'] = mock_parent_agent.llm
|
||||
|
||||
parent_metrics = Metrics()
|
||||
parent_metrics.accumulated_cost = 2
|
||||
@@ -315,6 +394,7 @@ async def test_delegate_hits_global_limits(
|
||||
parent_controller = AgentController(
|
||||
agent=mock_parent_agent,
|
||||
event_stream=mock_event_stream,
|
||||
convo_stats=conversation_stats,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
sid='parent',
|
||||
confirmation_mode=False,
|
||||
|
||||
@@ -9,12 +9,13 @@ from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.events import EventStream, EventStreamSubscriber
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
@@ -22,44 +23,70 @@ from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Create a properly configured mock agent with all required nested attributes"""
|
||||
# Create the base mocks
|
||||
agent = MagicMock(spec=Agent)
|
||||
llm = MagicMock(spec=LLM)
|
||||
metrics = MagicMock(spec=Metrics)
|
||||
llm_config = MagicMock(spec=LLMConfig)
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
def mock_llm_registry():
|
||||
"""Create a mock LLM registry that properly simulates LLM registration"""
|
||||
config = OpenHandsConfig()
|
||||
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
|
||||
return registry
|
||||
|
||||
# Configure the LLM config
|
||||
llm_config.model = 'test-model'
|
||||
llm_config.base_url = 'http://test'
|
||||
llm_config.max_message_chars = 1000
|
||||
|
||||
# Configure the agent config
|
||||
agent_config.disabled_microagents = []
|
||||
agent_config.enable_mcp = True
|
||||
@pytest.fixture
|
||||
def mock_conversation_stats():
|
||||
"""Create a mock ConversationStats that properly simulates metrics tracking"""
|
||||
file_store = InMemoryFileStore({})
|
||||
stats = ConversationStats(
|
||||
file_store=file_store, conversation_id='test-conversation', user_id='test-user'
|
||||
)
|
||||
return stats
|
||||
|
||||
# Set up the chain of mocks
|
||||
llm.metrics = metrics
|
||||
llm.config = llm_config
|
||||
agent.llm = llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
agent.config = agent_config
|
||||
agent.prompt_manager = MagicMock()
|
||||
|
||||
return agent
|
||||
@pytest.fixture
|
||||
def connected_registry_and_stats(mock_llm_registry, mock_conversation_stats):
|
||||
"""Connect the LLMRegistry and ConversationStats properly"""
|
||||
# Subscribe to LLM registry events to track metrics
|
||||
mock_llm_registry.subscribe(mock_conversation_stats.register_llm)
|
||||
return mock_llm_registry, mock_conversation_stats
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_mock_agent():
|
||||
def _make_mock_agent(llm_registry):
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
llm_config = LLMConfig(
|
||||
model='gpt-4o',
|
||||
api_key='test_key',
|
||||
num_retries=2,
|
||||
retry_min_wait=1,
|
||||
retry_max_wait=2,
|
||||
)
|
||||
agent_config.disabled_microagents = []
|
||||
agent_config.enable_mcp = True
|
||||
llm_registry.service_to_llm.clear()
|
||||
mock_llm = llm_registry.get_llm('agent_llm', llm_config)
|
||||
agent.llm = mock_llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
agent.config = agent_config
|
||||
agent.prompt_manager = MagicMock()
|
||||
return agent
|
||||
|
||||
return _make_mock_agent
|
||||
|
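# Illustrative sketch (not part of this diff): an AgentSession is now created with both
# the registry and the conversation stats service, mirroring the fixtures above. The
# session and conversation ids are placeholders.
from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore

file_store = InMemoryFileStore({})
registry = LLMRegistry(config=OpenHandsConfig())
stats = ConversationStats(
    file_store=file_store, conversation_id='demo-conversation', user_id='demo-user'
)
registry.subscribe(stats.register_llm)
session = AgentSession(
    sid='demo-session',
    file_store=file_store,
    llm_registry=registry,
    convo_stats=stats,
)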
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_session_start_with_no_state(mock_agent):
|
||||
async def test_agent_session_start_with_no_state(
|
||||
make_mock_agent, mock_llm_registry, mock_conversation_stats
|
||||
):
|
||||
"""Test that AgentSession.start() works correctly when there's no state to restore"""
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@@ -140,13 +167,18 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
async def test_agent_session_start_with_restored_state(
|
||||
make_mock_agent, mock_llm_registry, mock_conversation_stats
|
||||
):
|
||||
"""Test that AgentSession.start() works correctly when there's a state to restore"""
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@@ -230,13 +262,21 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
"""Test that metrics are centralized and shared between controller and agent."""
|
||||
async def test_metrics_centralization_via_conversation_stats(
|
||||
make_mock_agent, connected_registry_and_stats
|
||||
):
|
||||
"""Test that metrics are centralized through the ConversationStats service."""
|
||||
|
||||
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@@ -262,6 +302,8 @@ async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# The registry already has a real metrics object set up in the fixture
|
||||
|
||||
# Patch necessary components
|
||||
with (
|
||||
patch(
|
||||
@@ -281,49 +323,50 @@ async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
# Verify that the agent's LLM metrics and controller's state metrics are the same object
|
||||
assert session.controller.agent.llm.metrics is session.controller.state.metrics
|
||||
# Verify that the ConversationStats is properly set up
|
||||
assert session.controller.state.convo_stats is mock_conversation_stats
|
||||
|
||||
# Add some metrics to the agent's LLM
|
||||
# Add some metrics to the agent's LLM (simulating LLM usage)
|
||||
test_cost = 0.05
|
||||
session.controller.agent.llm.metrics.add_cost(test_cost)
|
||||
|
||||
# Verify that the cost is reflected in the controller's state metrics
|
||||
assert session.controller.state.metrics.accumulated_cost == test_cost
|
||||
# Verify that the cost is reflected in the combined metrics from the conversation stats
|
||||
combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
|
||||
assert combined_metrics.accumulated_cost == test_cost
|
||||
|
||||
# Create a test metrics object to simulate an observation with metrics
|
||||
test_observation_metrics = Metrics()
|
||||
test_observation_metrics.add_cost(0.1)
|
||||
# Add more cost to simulate additional LLM usage
|
||||
additional_cost = 0.1
|
||||
session.controller.agent.llm.metrics.add_cost(additional_cost)
|
||||
|
||||
# Get the current accumulated cost before merging
|
||||
current_cost = session.controller.state.metrics.accumulated_cost
|
||||
# Verify the combined metrics reflect the total cost
|
||||
combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
|
||||
assert combined_metrics.accumulated_cost == test_cost + additional_cost
|
||||
|
||||
# Simulate merging metrics from an observation
|
||||
session.controller.state_tracker.merge_metrics(test_observation_metrics)
|
||||
|
||||
# Verify that the merged metrics are reflected in both agent and controller
|
||||
assert session.controller.state.metrics.accumulated_cost == current_cost + 0.1
|
||||
assert (
|
||||
session.controller.agent.llm.metrics.accumulated_cost == current_cost + 0.1
|
||||
)
|
||||
|
||||
# Reset the agent and verify that metrics are not reset
|
||||
# Reset the agent and verify that combined metrics are preserved
|
||||
session.controller.agent.reset()
|
||||
|
||||
# Metrics should still be the same after reset
|
||||
assert session.controller.state.metrics.accumulated_cost == test_cost + 0.1
|
||||
assert session.controller.agent.llm.metrics.accumulated_cost == test_cost + 0.1
|
||||
assert session.controller.agent.llm.metrics is session.controller.state.metrics
|
||||
# Combined metrics should still be preserved after agent reset
|
||||
assert (
|
||||
session.controller.state.convo_stats.get_combined_metrics().accumulated_cost
|
||||
== test_cost + additional_cost
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
async def test_budget_control_flag_syncs_with_metrics(
|
||||
make_mock_agent, connected_registry_and_stats
|
||||
):
|
||||
"""Test that BudgetControlFlag's current value matches the accumulated costs."""
|
||||
|
||||
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@@ -349,6 +392,8 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# The registry already has a real metrics object set up in the fixture
|
||||
|
||||
# Patch necessary components
|
||||
with (
|
||||
patch(
|
||||
@@ -375,7 +420,7 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
assert session.controller.state.budget_flag.max_value == 1.0
|
||||
assert session.controller.state.budget_flag.current_value == 0.0
|
||||
|
||||
# Add some metrics to the agent's LLM
|
||||
# Add some metrics to the agent's LLM (simulating LLM usage)
|
||||
test_cost = 0.05
|
||||
session.controller.agent.llm.metrics.add_cost(test_cost)
|
||||
|
||||
@@ -384,24 +429,31 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert session.controller.state.budget_flag.current_value == test_cost
|
||||
|
||||
# Create a test metrics object to simulate an observation with metrics
|
||||
test_observation_metrics = Metrics()
|
||||
test_observation_metrics.add_cost(0.1)
|
||||
# Add more cost to simulate additional LLM usage
|
||||
additional_cost = 0.1
|
||||
session.controller.agent.llm.metrics.add_cost(additional_cost)
|
||||
|
||||
# Simulate merging metrics from an observation
|
||||
session.controller.state_tracker.merge_metrics(test_observation_metrics)
|
||||
# Sync again and verify the budget flag is updated
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert (
|
||||
session.controller.state.budget_flag.current_value
|
||||
== test_cost + additional_cost
|
||||
)
|
||||
|
||||
# Verify that the budget control flag's current value is updated to match the new accumulated cost
|
||||
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
|
||||
|
||||
# Reset the agent and verify that metrics and budget flag are not reset
|
||||
# Reset the agent and verify that budget flag still reflects the accumulated cost
|
||||
session.controller.agent.reset()
|
||||
|
||||
# Budget control flag should still reflect the accumulated cost after reset
|
||||
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert (
|
||||
session.controller.state.budget_flag.current_value
|
||||
== test_cost + additional_cost
|
||||
)
|
||||
|
||||
|
||||
def test_override_provider_tokens_with_custom_secret():
|
||||
def test_override_provider_tokens_with_custom_secret(
|
||||
mock_llm_registry, mock_conversation_stats
|
||||
):
|
||||
"""Test that override_provider_tokens_with_custom_secret works correctly.
|
||||
|
||||
This test verifies that the method properly removes provider tokens when
|
||||
@@ -413,6 +465,8 @@ def test_override_provider_tokens_with_custom_secret():
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create test data
|
||||
|
||||
@@ -30,6 +30,7 @@ from openhands.agenthub.readonly_agent.tools import (
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.exceptions import FunctionCallNotExistsError
|
||||
from openhands.core.message import ImageContent, Message, TextContent
|
||||
from openhands.events.action import (
|
||||
@@ -42,10 +43,20 @@ from openhands.events.observation.commands import (
|
||||
CmdOutputObservation,
|
||||
)
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser import View
|
||||
|
||||
|
||||
@pytest.fixture
def create_llm_registry():
    def _get_registry(llm_config):
        config = OpenHandsConfig()
        config.set_llm_config(llm_config)
        return LLMRegistry(config=config)

    return _get_registry

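# Illustrative sketch (not part of this diff): agents are now constructed from a
# registry rather than a pre-built LLM, matching the fixture above.
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
config = OpenHandsConfig()
config.set_llm_config(llm_config)
registry = LLMRegistry(config=config)

agent = CodeActAgent(config=AgentConfig(), llm_registry=registry)
# The agent's LLM comes from the registry, so its usage is tracked centrally.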
@pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent'])
|
||||
def agent_class(request):
|
||||
if request.param == 'CodeActAgent':
|
||||
@@ -57,18 +68,22 @@ def agent_class(request):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent(agent_class) -> Union[CodeActAgent, ReadOnlyAgent]:
|
||||
def agent(agent_class, create_llm_registry) -> Union[CodeActAgent, ReadOnlyAgent]:
|
||||
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
config = AgentConfig()
|
||||
agent = agent_class(llm=LLM(LLMConfig()), config=config)
|
||||
agent = agent_class(config=config, llm_registry=create_llm_registry(llm_config))
|
||||
agent.llm = Mock()
|
||||
agent.llm.config = Mock()
|
||||
agent.llm.config.max_message_chars = 1000
|
||||
return agent
|
||||
|
||||
|
||||
def test_agent_with_default_config_has_default_tools():
|
||||
def test_agent_with_default_config_has_default_tools(create_llm_registry):
|
||||
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
config = AgentConfig()
|
||||
codeact_agent = CodeActAgent(llm=LLM(LLMConfig()), config=config)
|
||||
codeact_agent = CodeActAgent(
|
||||
config=config, llm_registry=create_llm_registry(llm_config)
|
||||
)
|
||||
assert len(codeact_agent.tools) > 0
|
||||
default_tool_names = [tool['function']['name'] for tool in codeact_agent.tools]
|
||||
assert {
|
||||
@@ -231,7 +246,7 @@ def test_response_to_actions_invalid_tool():
|
||||
readonly_response_to_actions(mock_response)
|
||||
|
||||
|
||||
def test_step_with_no_pending_actions(mock_state: State):
|
||||
def test_step_with_no_pending_actions(mock_state: State, create_llm_registry):
|
||||
# Mock the LLM response
|
||||
mock_response = Mock()
|
||||
mock_response.id = 'mock_id'
|
||||
@@ -252,9 +267,12 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
llm.format_messages_for_llm = Mock(return_value=[]) # Mock message formatting
|
||||
|
||||
# Create agent with mocked LLM
|
||||
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
config = AgentConfig()
|
||||
config.enable_prompt_extensions = False
|
||||
agent = CodeActAgent(llm=llm, config=config)
|
||||
agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))
|
||||
# Replace the LLM with our mock after creation
|
||||
agent.llm = llm
|
||||
|
||||
# Test step with no pending actions
|
||||
mock_state.latest_user_message = None
|
||||
@@ -281,15 +299,10 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
|
||||
@pytest.mark.parametrize('agent_type', ['CodeActAgent', 'ReadOnlyAgent'])
|
||||
def test_correct_tool_description_loaded_based_on_model_name(
|
||||
agent_type, mock_state: State
|
||||
agent_type, create_llm_registry
|
||||
):
|
||||
"""Tests that the simplified tool descriptions are loaded for specific models."""
|
||||
o3_mock_config = Mock()
|
||||
o3_mock_config.model = 'mock_o3_model'
|
||||
|
||||
llm = Mock()
|
||||
llm.config = o3_mock_config
|
||||
|
||||
o3_mock_config = LLMConfig(model='mock_o3_model', api_key='test_key')
|
||||
if agent_type == 'CodeActAgent':
|
||||
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
|
||||
@@ -299,16 +312,19 @@ def test_correct_tool_description_loaded_based_on_model_name(
|
||||
|
||||
agent_class = ReadOnlyAgent
|
||||
|
||||
agent = agent_class(llm=llm, config=AgentConfig())
|
||||
agent = agent_class(
|
||||
config=AgentConfig(),
|
||||
llm_registry=create_llm_registry(o3_mock_config),
|
||||
)
|
||||
for tool in agent.tools:
|
||||
# Assert all descriptions have less than 1024 characters
|
||||
assert len(tool['function']['description']) < 1024
|
||||
|
||||
sonnet_mock_config = Mock()
|
||||
sonnet_mock_config.model = 'mock_sonnet_model'
|
||||
|
||||
llm.config = sonnet_mock_config
|
||||
agent = agent_class(llm=llm, config=AgentConfig())
|
||||
sonnect_mock_config = LLMConfig(model='mock_sonnet_model', api_key='test_key')
|
||||
agent = agent_class(
|
||||
config=AgentConfig(),
|
||||
llm_registry=create_llm_registry(sonnect_mock_config),
|
||||
)
|
||||
# Assert existence of the detailed tool descriptions that are longer than 1024 characters
|
||||
if agent_type == 'CodeActAgent':
|
||||
# This only holds for CodeActAgent
|
||||
@@ -481,10 +497,12 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
|
||||
assert isinstance(enhanced_messages[5].content[0], ImageContent)
|
||||
|
||||
|
||||
def test_get_system_message():
|
||||
def test_get_system_message(create_llm_registry):
|
||||
"""Test that the Agent.get_system_message method returns a SystemMessageAction."""
|
||||
# Create a mock agent
|
||||
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
|
||||
config = AgentConfig()
|
||||
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
|
||||
agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))
|
||||
|
||||
result = agent.get_system_message()
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def test_completion_retries_api_connection_error(
|
||||
]
|
||||
|
||||
# Create an LLM instance and call completion
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
response = llm.completion(
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}],
|
||||
stream=False,
|
||||
@@ -70,7 +70,7 @@ def test_completion_max_retries_api_connection_error(
|
||||
]
|
||||
|
||||
# Create an LLM instance and call completion
|
||||
llm = LLM(config=default_config)
|
||||
llm = LLM(config=default_config, service_id='test-service')
|
||||
|
||||
# The completion should raise an APIConnectionError after exhausting all retries
|
||||
with pytest.raises(APIConnectionError) as excinfo:
|
||||
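# Illustrative sketch (not part of this diff): LLM instances now take an explicit
# service_id, as the updated retry tests above show; 'my-service' is a placeholder.
from openhands.core.config.llm_config import LLMConfig
from openhands.llm.llm import LLM

llm = LLM(config=LLMConfig(model='gpt-4o', api_key='test_key'), service_id='my-service')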
|
||||
@@ -5,11 +5,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.server.conversation_manager.standalone_conversation_manager import (
|
||||
StandaloneConversationManager,
|
||||
)
|
||||
@@ -24,6 +24,7 @@ async def test_auto_generate_title_with_llm():
|
||||
"""Test auto-generating a title using LLM."""
|
||||
# Mock dependencies
|
||||
file_store = InMemoryFileStore()
|
||||
llm_registry = MagicMock(spec=LLMRegistry)
|
||||
|
||||
# Create test conversation with a user message
|
||||
conversation_id = 'test-conversation'
|
||||
@@ -46,43 +47,33 @@ async def test_auto_generate_title_with_llm():
|
||||
mock_event_store.search_events.return_value = [user_message]
|
||||
mock_event_store_cls.return_value = mock_event_store
|
||||
|
||||
# Mock the LLM response
|
||||
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
|
||||
mock_llm = mock_llm_cls.return_value
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Python Data Analysis Script'
|
||||
mock_llm.completion.return_value = mock_response
|
||||
# Mock the LLM registry response
|
||||
llm_registry.request_extraneous_completion.return_value = (
|
||||
'Python Data Analysis Script'
|
||||
)
|
||||
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings
|
||||
)
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings, llm_registry
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert title == 'Python Data Analysis Script'
|
||||
# Verify the result
|
||||
assert title == 'Python Data Analysis Script'
|
||||
|
||||
# Verify EventStore was created with the correct parameters
|
||||
mock_event_store_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
# Verify EventStore was created with the correct parameters
|
||||
mock_event_store_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
|
||||
# Verify LLM was called with appropriate parameters
|
||||
mock_llm_cls.assert_called_once_with(
|
||||
LLMConfig(
|
||||
model='test-model',
|
||||
api_key='test-key',
|
||||
base_url='test-url',
|
||||
)
|
||||
)
|
||||
mock_llm.completion.assert_called_once()
|
||||
# Verify LLM registry was called with appropriate parameters
|
||||
llm_registry.request_extraneous_completion.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -90,6 +81,7 @@ async def test_auto_generate_title_fallback():
|
||||
"""Test auto-generating a title with fallback to truncation when LLM fails."""
|
||||
# Mock dependencies
|
||||
file_store = InMemoryFileStore()
|
||||
llm_registry = MagicMock(spec=LLMRegistry)
|
||||
|
||||
# Create test conversation with a user message
|
||||
conversation_id = 'test-conversation'
|
||||
@@ -111,31 +103,29 @@ async def test_auto_generate_title_fallback():
|
||||
mock_event_store.search_events.return_value = [user_message]
|
||||
mock_event_store_cls.return_value = mock_event_store
|
||||
|
||||
# Mock the LLM to raise an exception
|
||||
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
|
||||
mock_llm = mock_llm_cls.return_value
|
||||
mock_llm.completion.side_effect = Exception('Test error')
|
||||
# Mock the LLM registry to raise an exception
|
||||
llm_registry.request_extraneous_completion.side_effect = Exception('Test error')
|
||||
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
# Create test settings with LLM config
|
||||
settings = Settings(
|
||||
llm_model='test-model',
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='test-url',
|
||||
)
|
||||
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings
|
||||
)
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings, llm_registry
|
||||
)
|
||||
|
||||
# Verify the result is a truncated version of the message
|
||||
assert title == 'This is a very long message th...'
|
||||
assert len(title) <= 35
|
||||
# Verify the result is a truncated version of the message
|
||||
assert title == 'This is a very long message th...'
|
||||
assert len(title) <= 35
|
||||
|
||||
# Verify EventStore was created with the correct parameters
|
||||
mock_event_store_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
# Verify EventStore was created with the correct parameters
|
||||
mock_event_store_cls.assert_called_once_with(
|
||||
conversation_id, file_store, user_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -143,6 +133,7 @@ async def test_auto_generate_title_no_messages():
|
||||
"""Test auto-generating a title when there are no user messages."""
|
||||
# Mock dependencies
|
||||
file_store = InMemoryFileStore()
|
||||
llm_registry = MagicMock(spec=LLMRegistry)
|
||||
|
||||
# Create test conversation with no messages
|
||||
conversation_id = 'test-conversation'
|
||||
@@ -166,7 +157,7 @@ async def test_auto_generate_title_no_messages():
|
||||
|
||||
# Call the auto_generate_title function directly
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, file_store, settings
|
||||
conversation_id, user_id, file_store, settings, llm_registry
|
||||
)
|
||||
|
||||
# Verify the result is empty
|
||||
@@ -186,6 +177,7 @@ async def test_update_conversation_with_title():
|
||||
sio.emit = AsyncMock()
|
||||
file_store = InMemoryFileStore()
|
||||
server_config = MagicMock()
|
||||
llm_registry = MagicMock(spec=LLMRegistry)
|
||||
|
||||
# Create test conversation
|
||||
conversation_id = 'test-conversation'
|
||||
@@ -222,7 +214,9 @@ async def test_update_conversation_with_title():
|
||||
AsyncMock(return_value='Generated Title'),
|
||||
):
|
||||
# Call the method
|
||||
await manager._update_conversation_for_event(user_id, conversation_id, settings)
|
||||
await manager._update_conversation_for_event(
|
||||
user_id, conversation_id, settings, llm_registry
|
||||
)
|
||||
|
||||
# Verify the title was updated
|
||||
assert mock_metadata.title == 'Generated Title'
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest_asyncio
|
||||
|
||||
from openhands.cli import main as cli
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import MessageAction
|
||||
|
||||
@@ -124,12 +125,14 @@ def mock_config():
|
||||
'' # Empty string, not starting with 'tvly-'
|
||||
)
|
||||
config.search_api_key = search_api_key_mock
|
||||
config.get_llm_config_from_agent.return_value = LLMConfig(model='model')
|
||||
|
||||
# Mock sandbox with volumes attribute to prevent finalize_config issues
|
||||
config.sandbox = MagicMock()
|
||||
config.sandbox.volumes = (
|
||||
None # This prevents finalize_config from overriding workspace_base
|
||||
)
|
||||
config.model_name = 'model'
|
||||
|
||||
return config
|
||||
|
||||
@@ -213,7 +216,11 @@ async def test_run_session_without_initial_action(
|
||||
# Assertions for initialization flow
|
||||
mock_display_runtime_init.assert_called_once_with('local')
|
||||
mock_display_animation.assert_called_once()
|
||||
mock_create_agent.assert_called_once_with(mock_config)
|
||||
# Check that mock_config is the first parameter to create_agent
|
||||
mock_create_agent.assert_called_once()
|
||||
assert mock_create_agent.call_args[0][0] == mock_config, (
|
||||
'First parameter to create_agent should be mock_config'
|
||||
)
|
||||
mock_add_mcp_tools.assert_called_once_with(mock_agent, mock_runtime, mock_memory)
|
||||
mock_create_runtime.assert_called_once()
|
||||
mock_create_controller.assert_called_once()
|
||||
|
||||
@@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from litellm.exceptions import AuthenticationError
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.cli import main as cli
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import MessageAction
|
||||
|
||||
@@ -45,11 +47,10 @@ def mock_config():
|
||||
config.workspace_base = '/test/dir'
|
||||
|
||||
# Set up LLM config to use OpenHands provider
|
||||
llm_config = MagicMock()
|
||||
llm_config = LLMConfig(model='openhands/o3', api_key=SecretStr('invalid-api-key'))
|
||||
llm_config.model = 'openhands/o3' # Use OpenHands provider with o3 model
|
||||
llm_config.api_key = MagicMock()
|
||||
llm_config.api_key.get_secret_value.return_value = 'invalid-api-key'
|
||||
config.llm = llm_config
|
||||
config.get_llm_config.return_value = llm_config
|
||||
config.get_llm_config_from_agent.return_value = llm_config
|
||||
|
||||
# Mock search_api_key with get_secret_value method
|
||||
search_api_key_mock = MagicMock()
|
||||
|
||||
@@ -13,6 +13,7 @@ from openhands.core.config.mcp_config import (
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation import ErrorObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
|
||||
|
||||
|
||||
@@ -23,8 +24,12 @@ class TestCLIRuntimeMCP:
        """Set up test fixtures."""
        self.config = OpenHandsConfig()
        self.event_stream = MagicMock()
        llm_registry = LLMRegistry(config=OpenHandsConfig())
        self.runtime = CLIRuntime(
            config=self.config, event_stream=self.event_stream, sid='test-session'
            config=self.config,
            event_stream=self.event_stream,
            sid='test-session',
            llm_registry=llm_registry,
        )
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -7,10 +7,18 @@ import pytest
|
||||
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.events import EventStream
|
||||
|
||||
# Mock LLMRegistry
|
||||
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
# Create a mock LLMRegistry class
|
||||
class MockLLMRegistry:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
"""Create a temporary directory for testing."""
|
||||
@@ -25,7 +33,8 @@ def cli_runtime(temp_dir):
|
||||
event_stream = EventStream('test', file_store)
|
||||
config = OpenHandsConfig()
|
||||
config.workspace_base = temp_dir
|
||||
runtime = CLIRuntime(config, event_stream)
|
||||
llm_registry = MockLLMRegistry(config)
|
||||
runtime = CLIRuntime(config, event_stream, llm_registry)
|
||||
runtime._runtime_initialized = True # Skip initialization
|
||||
return runtime
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from openhands.core.config.condenser_config import (
|
||||
StructuredSummaryCondenserConfig,
|
||||
)
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.core.schema.action import ActionType
|
||||
from openhands.events.event import Event, EventSource
|
||||
@@ -24,6 +25,7 @@ from openhands.events.observation import BrowserOutputObservation
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.memory.condenser import Condenser
|
||||
from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
|
||||
from openhands.memory.condenser.impl import (
|
||||
@@ -38,6 +40,7 @@ from openhands.memory.condenser.impl import (
|
||||
StructuredSummaryCondenser,
|
||||
)
|
||||
from openhands.memory.condenser.impl.pipeline import CondenserPipeline
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
|
||||
|
||||
def create_test_event(
|
||||
@@ -56,12 +59,15 @@ def create_test_event(
|
||||
@pytest.fixture
|
||||
def mock_llm() -> LLM:
|
||||
"""Mocks an LLM object with a utility function for setting and resetting response contents in unit tests."""
|
||||
# Create a real LLMConfig instead of a mock to properly handle SecretStr api_key
|
||||
real_config = LLMConfig(
|
||||
model='gpt-4o', api_key='test_key', custom_llm_provider=None
|
||||
)
|
||||
|
||||
# Create a MagicMock for the LLM object
|
||||
mock_llm = MagicMock(
|
||||
spec=LLM,
|
||||
config=MagicMock(
|
||||
spec=LLMConfig, model='gpt-4o', api_key='test_key', custom_llm_provider=None
|
||||
),
|
||||
config=real_config,
|
||||
metrics=MagicMock(),
|
||||
)
|
||||
_mock_content = None
|
||||
@@ -95,6 +101,23 @@ def mock_llm() -> LLM:
|
||||
return mock_llm
|
||||
|
||||
|
||||
@pytest.fixture
def mock_conversation_stats() -> ConversationStats:
    """Creates a mock ConversationStats service."""
    mock_stats = MagicMock(spec=ConversationStats)
    return mock_stats


@pytest.fixture
def mock_llm_registry(mock_llm, mock_conversation_stats) -> LLMRegistry:
    """Creates an actual LLMRegistry that returns real LLMs."""
    # Create an actual LLMRegistry with a basic OpenHandsConfig
    config = OpenHandsConfig()
    registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)

    return registry

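# Illustrative sketch (not part of this diff): condensers are now built from a config
# plus a registry, the pattern the tests below exercise via Condenser.from_config. The
# import path for LLMSummarizingCondenserConfig is assumed to match the condenser_config
# module used in this test file.
from openhands.core.config.condenser_config import LLMSummarizingCondenserConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser import Condenser

registry = LLMRegistry(config=OpenHandsConfig())
condenser = Condenser.from_config(
    LLMSummarizingCondenserConfig(
        max_size=50,
        keep_first=10,
        llm_config=LLMConfig(model='gpt-4o', api_key='test_key'),
    ),
    registry,
)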
class RollingCondenserTestHarness:
|
||||
"""Test harness for rolling condensers.
|
||||
|
||||
@@ -165,10 +188,10 @@ class RollingCondenserTestHarness:
|
||||
return ((index - max_size) // target_size) + 1
|
||||
|
||||
|
||||
def test_noop_condenser_from_config():
|
||||
def test_noop_condenser_from_config(mock_llm_registry):
|
||||
"""Test that the NoOpCondenser objects can be made from config."""
|
||||
config = NoOpCondenserConfig()
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, NoOpCondenser)
|
||||
|
||||
@@ -189,11 +212,11 @@ def test_noop_condenser():
|
||||
assert result == View(events=events)
|
||||
|
||||
|
||||
def test_observation_masking_condenser_from_config():
|
||||
def test_observation_masking_condenser_from_config(mock_llm_registry):
|
||||
"""Test that ObservationMaskingCondenser objects can be made from config."""
|
||||
attention_window = 5
|
||||
config = ObservationMaskingCondenserConfig(attention_window=attention_window)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, ObservationMaskingCondenser)
|
||||
assert condenser.attention_window == attention_window
|
||||
@@ -229,11 +252,11 @@ def test_observation_masking_condenser_respects_attention_window():
|
||||
assert event == condensed_event
|
||||
|
||||
|
||||
def test_browser_output_condenser_from_config():
|
||||
def test_browser_output_condenser_from_config(mock_llm_registry):
|
||||
"""Test that BrowserOutputCondenser objects can be made from config."""
|
||||
attention_window = 5
|
||||
config = BrowserOutputCondenserConfig(attention_window=attention_window)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, BrowserOutputCondenser)
|
||||
assert condenser.attention_window == attention_window
|
||||
@@ -271,12 +294,12 @@ def test_browser_output_condenser_respects_attention_window():
|
||||
assert event == condensed_event
|
||||
|
||||
|
||||
def test_recent_events_condenser_from_config():
|
||||
def test_recent_events_condenser_from_config(mock_llm_registry):
|
||||
"""Test that RecentEventsCondenser objects can be made from config."""
|
||||
max_events = 5
|
||||
keep_first = True
|
||||
config = RecentEventsCondenserConfig(keep_first=keep_first, max_events=max_events)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, RecentEventsCondenser)
|
||||
assert condenser.max_events == max_events
|
||||
@@ -334,14 +357,14 @@ def test_recent_events_condenser():
|
||||
assert result[2]._message == 'Event 5' # kept from max_events
|
||||
|
||||
|
||||
def test_llm_summarizing_condenser_from_config():
|
||||
def test_llm_summarizing_condenser_from_config(mock_llm_registry):
|
||||
"""Test that LLMSummarizingCondenser objects can be made from config."""
|
||||
config = LLMSummarizingCondenserConfig(
|
||||
max_size=50,
|
||||
keep_first=10,
|
||||
llm_config=LLMConfig(model='gpt-4o', api_key='test_key', caching_prompt=True),
|
||||
)
|
||||
condenser = Condenser.from_config(config)
|
||||
condenser = Condenser.from_config(config, mock_llm_registry)
|
||||
|
||||
assert isinstance(condenser, LLMSummarizingCondenser)
|
||||
assert condenser.llm.config.model == 'gpt-4o'
|
||||
@@ -349,25 +372,33 @@ def test_llm_summarizing_condenser_from_config():
|
||||
assert condenser.max_size == 50
|
||||
assert condenser.keep_first == 10
|
||||
|
||||
# Since this condenser can't take advantage of caching, we intercept the
|
||||
# passed config and manually flip the caching prompt to False.
|
||||
assert not condenser.llm.config.caching_prompt
|
||||
|
||||
|
||||
def test_llm_summarizing_condenser_invalid_config():
|
||||
def test_llm_summarizing_condenser_invalid_config(mock_llm, mock_llm_registry):
"""Test that LLMSummarizingCondenser raises error when keep_first > max_size."""
pytest.raises(
ValueError,
LLMSummarizingCondenser,
llm=MagicMock(),
llm=mock_llm,
max_size=4,
keep_first=2,
)
pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), max_size=0)
pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), keep_first=-1)
pytest.raises(
ValueError,
LLMSummarizingCondenser,
llm=mock_llm,
max_size=0,
)
pytest.raises(
ValueError,
LLMSummarizingCondenser,
llm=mock_llm,
keep_first=-1,
)

def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
def test_llm_summarizing_condenser_gives_expected_view_size(
mock_llm, mock_llm_registry
):
"""Test that LLMSummarizingCondenser maintains the correct view size."""
max_size = 10
condenser = LLMSummarizingCondenser(max_size=max_size, llm=mock_llm)
@@ -383,12 +414,16 @@ def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size)

def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
def test_llm_summarizing_condenser_keeps_first_and_summary_events(
mock_llm, mock_llm_registry
):
"""Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events."""
max_size = 10
keep_first = 3
condenser = LLMSummarizingCondenser(
max_size=max_size, keep_first=keep_first, llm=mock_llm
max_size=max_size,
keep_first=keep_first,
llm=mock_llm,
)

mock_llm.set_mock_response_content('Summary of forgotten events')
@@ -412,14 +447,14 @@ def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
assert isinstance(view[keep_first], AgentCondensationObservation)

def test_amortized_forgetting_condenser_from_config():
def test_amortized_forgetting_condenser_from_config(mock_llm_registry):
"""Test that AmortizedForgettingCondenser objects can be made from config."""
max_size = 50
keep_first = 10
config = AmortizedForgettingCondenserConfig(
max_size=max_size, keep_first=keep_first
)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)

assert isinstance(condenser, AmortizedForgettingCondenser)
assert condenser.max_size == max_size
@@ -475,7 +510,7 @@ def test_amortized_forgetting_condenser_keeps_first_and_last_events():
assert view[:keep_first] == events[: min(keep_first, i + 1)]

def test_llm_attention_condenser_from_config():
def test_llm_attention_condenser_from_config(mock_llm_registry):
"""Test that LLMAttentionCondenser objects can be made from config."""
config = LLMAttentionCondenserConfig(
max_size=50,
@@ -486,37 +521,32 @@ def test_llm_attention_condenser_from_config():
caching_prompt=True,
),
)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)

assert isinstance(condenser, LLMAttentionCondenser)
assert condenser.llm.config.model == 'gpt-4o'
assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
assert condenser.max_size == 50
assert condenser.keep_first == 10

# Since this condenser can't take advantage of caching, we intercept the
# passed config and manually flip the caching prompt to False.
assert not condenser.llm.config.caching_prompt
# Create a mock LLM that doesn't support function calling
mock_llm = MagicMock()
mock_llm.is_function_calling_active.return_value = False

# Create a new registry that returns our mock LLM that doesn't support function calling
mock_registry = MagicMock(spec=LLMRegistry)
mock_registry.get_llm.return_value = mock_llm

pytest.raises(ValueError, LLMAttentionCondenser.from_config, config, mock_registry)

def test_llm_attention_condenser_invalid_config():
"""Test that LLMAttentionCondenser raises an error if the configured LLM doesn't support response schema."""
config = LLMAttentionCondenserConfig(
max_size=50,
keep_first=10,
llm_config=LLMConfig(
model='claude-2', # Older model that doesn't support response schema
api_key='test_key',
),
)

pytest.raises(ValueError, LLMAttentionCondenser.from_config, config)

def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
def test_llm_attention_condenser_gives_expected_view_size(mock_llm, mock_llm_registry):
"""Test that the LLMAttentionCondenser gives views of the expected size."""
max_size = 10
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)

events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]

@@ -534,10 +564,16 @@ def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size)

def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
def test_llm_attention_condenser_handles_events_outside_history(
mock_llm, mock_llm_registry
):
"""Test that the LLMAttentionCondenser handles event IDs that aren't from the event history."""
max_size = 2
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)

events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]

@@ -555,10 +591,14 @@ def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
assert len(view) == harness.expected_size(i, max_size)

def test_llm_attention_condenser_handles_too_many_events(mock_llm):
def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_llm_registry):
"""Test that the LLMAttentionCondenser handles when the response contains too many event IDs."""
max_size = 2
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)

events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]

@@ -576,12 +616,16 @@ def test_llm_attention_condenser_handles_too_many_events(mock_llm):
assert len(view) == harness.expected_size(i, max_size)

def test_llm_attention_condenser_handles_too_few_events(mock_llm):
def test_llm_attention_condenser_handles_too_few_events(mock_llm, mock_llm_registry):
"""Test that the LLMAttentionCondenser handles when the response contains too few event IDs."""
max_size = 2
# Developer note: We must specify keep_first=0 because
# keep_first (1) >= max_size//2 (1) is invalid.
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)

events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]

@@ -597,12 +641,14 @@ def test_llm_attention_condenser_handles_too_few_events(mock_llm):
assert len(view) == harness.expected_size(i, max_size)

def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
def test_llm_attention_condenser_handles_keep_first_events(mock_llm, mock_llm_registry):
"""Test that LLMAttentionCondenser works when keep_first=1 is allowed (must be less than half of max_size)."""
max_size = 12
keep_first = 4
condenser = LLMAttentionCondenser(
max_size=max_size, keep_first=keep_first, llm=mock_llm
max_size=max_size,
keep_first=keep_first,
llm=mock_llm,
)

events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -620,7 +666,7 @@ def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
assert view[:keep_first] == events[: min(keep_first, i + 1)]

def test_structured_summary_condenser_from_config():
def test_structured_summary_condenser_from_config(mock_llm_registry):
"""Test that StructuredSummaryCondenser objects can be made from config."""
config = StructuredSummaryCondenserConfig(
max_size=50,
@@ -631,7 +677,7 @@ def test_structured_summary_condenser_from_config():
caching_prompt=True,
),
)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)

assert isinstance(condenser, StructuredSummaryCondenser)
assert condenser.llm.config.model == 'gpt-4o'
@@ -639,40 +685,55 @@ def test_structured_summary_condenser_from_config():
assert condenser.max_size == 50
assert condenser.keep_first == 10

# Since this condenser can't take advantage of caching, we intercept the
# passed config and manually flip the caching prompt to False.
assert not condenser.llm.config.caching_prompt

def test_structured_summary_condenser_invalid_config():
def test_structured_summary_condenser_invalid_config(mock_llm):
"""Test that StructuredSummaryCondenser raises error when keep_first > max_size."""
# Since the condenser only works when function calling is on, we need to
# mock up the check for that.
llm = MagicMock()
llm.is_function_calling_active.return_value = True
mock_llm.is_function_calling_active.return_value = True

pytest.raises(
ValueError,
StructuredSummaryCondenser,
llm=llm,
llm=mock_llm,
max_size=4,
keep_first=2,
)

pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, max_size=0)
pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, keep_first=-1)
pytest.raises(
ValueError,
StructuredSummaryCondenser,
llm=mock_llm,
max_size=0,
)
pytest.raises(
ValueError,
StructuredSummaryCondenser,
llm=mock_llm,
keep_first=-1,
)

# If all other parameters are good but there's no function calling the
# condenser still counts as improperly configured.
llm.is_function_calling_active.return_value = False
# Create a mock LLM that doesn't support function calling
mock_llm_no_func = MagicMock()
mock_llm_no_func.is_function_calling_active.return_value = False

pytest.raises(
ValueError, StructuredSummaryCondenser, llm=llm, max_size=40, keep_first=2
ValueError,
StructuredSummaryCondenser,
llm=mock_llm_no_func,
max_size=40,
keep_first=2,
)

def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
def test_structured_summary_condenser_gives_expected_view_size(
mock_llm, mock_llm_registry
):
"""Test that StructuredSummaryCondenser maintains the correct view size."""
max_size = 10
mock_llm.is_function_calling_active.return_value = True
condenser = StructuredSummaryCondenser(max_size=max_size, llm=mock_llm)

events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -686,12 +747,17 @@ def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size)

def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
def test_structured_summary_condenser_keeps_first_and_summary_events(
mock_llm, mock_llm_registry
):
"""Test that the StructuredSummaryCondenser appropriately maintains the event prefix and any summary events."""
max_size = 10
keep_first = 3
mock_llm.is_function_calling_active.return_value = True
condenser = StructuredSummaryCondenser(
max_size=max_size, keep_first=keep_first, llm=mock_llm
max_size=max_size,
keep_first=keep_first,
llm=mock_llm,
)

mock_llm.set_mock_response_content('Summary of forgotten events')
@@ -715,7 +781,7 @@ def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
assert isinstance(view[keep_first], AgentCondensationObservation)

def test_condenser_pipeline_from_config():
def test_condenser_pipeline_from_config(mock_llm_registry):
"""Test that CondenserPipeline condensers can be created from configuration objects."""
config = CondenserPipelineConfig(
condensers=[
@@ -728,7 +794,7 @@ def test_condenser_pipeline_from_config():
),
]
)
condenser = Condenser.from_config(config)
condenser = Condenser.from_config(config, mock_llm_registry)

assert isinstance(condenser, CondenserPipeline)
assert len(condenser.condensers) == 3

490
tests/unit/test_conversation_stats.py
Normal file
@@ -0,0 +1,490 @@
import base64
import pickle
from unittest.mock import patch

import pytest

from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
from openhands.llm.metrics import Metrics
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

@pytest.fixture
def mock_file_store():
"""Create a mock file store for testing."""
return InMemoryFileStore({})

@pytest.fixture
def conversation_stats(mock_file_store):
"""Create a ConversationStats instance for testing."""
return ConversationStats(
file_store=mock_file_store,
conversation_id='test-conversation-id',
user_id='test-user-id',
)

@pytest.fixture
def mock_llm_registry():
"""Create a mock LLM registry that properly simulates LLM registration."""
config = OpenHandsConfig()
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
return registry

@pytest.fixture
def connected_registry_and_stats(mock_llm_registry, conversation_stats):
"""Connect the LLMRegistry and ConversationStats properly."""
# Subscribe to LLM registry events to track metrics
mock_llm_registry.subscribe(conversation_stats.register_llm)
return mock_llm_registry, conversation_stats

def test_conversation_stats_initialization(conversation_stats):
"""Test that ConversationStats initializes correctly."""
assert conversation_stats.conversation_id == 'test-conversation-id'
assert conversation_stats.user_id == 'test-user-id'
assert conversation_stats.service_to_metrics == {}
assert isinstance(conversation_stats.restored_metrics, dict)

def test_save_metrics(conversation_stats, mock_file_store):
"""Test that metrics are saved correctly."""
# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
conversation_stats.service_to_metrics[service_id] = metrics

# Save metrics
conversation_stats.save_metrics()

# Verify that metrics were saved to the file store
try:
# Verify the saved content can be decoded and unpickled
encoded = mock_file_store.read(conversation_stats.metrics_path)
pickled = base64.b64decode(encoded)
restored = pickle.loads(pickled)

assert service_id in restored
assert restored[service_id].accumulated_cost == 0.05
except FileNotFoundError:
pytest.fail(f'File not found: {conversation_stats.metrics_path}')

def test_maybe_restore_metrics(mock_file_store):
"""Test that metrics are restored correctly."""
# Create metrics to save
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.1)
service_to_metrics = {service_id: metrics}

# Serialize and save metrics
pickled = pickle.dumps(service_to_metrics)
serialized_metrics = base64.b64encode(pickled).decode('utf-8')

# Create a new ConversationStats with pre-populated file store
conversation_id = 'test-conversation-id'
user_id = 'test-user-id'

# Get the correct path using the same function as ConversationStats
from openhands.storage.locations import get_conversation_stats_filename

metrics_path = get_conversation_stats_filename(conversation_id, user_id)

# Write to the correct path
mock_file_store.write(metrics_path, serialized_metrics)

# Create ConversationStats which should restore metrics
stats = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)

# Verify metrics were restored
assert service_id in stats.restored_metrics
assert stats.restored_metrics[service_id].accumulated_cost == 0.1

def test_get_combined_metrics(conversation_stats):
"""Test that combined metrics are calculated correctly."""
# Add multiple services with metrics
service1 = 'service1'
metrics1 = Metrics(model_name='gpt-4')
metrics1.add_cost(0.05)
metrics1.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)

service2 = 'service2'
metrics2 = Metrics(model_name='gpt-3.5')
metrics2.add_cost(0.02)
metrics2.add_token_usage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4000,
response_id='resp2',
)

conversation_stats.service_to_metrics[service1] = metrics1
conversation_stats.service_to_metrics[service2] = metrics2

# Get combined metrics
combined = conversation_stats.get_combined_metrics()

# Verify combined metrics
assert combined.accumulated_cost == 0.07 # 0.05 + 0.02
assert combined.accumulated_token_usage.prompt_tokens == 300 # 100 + 200
assert combined.accumulated_token_usage.completion_tokens == 150 # 50 + 100
assert (
combined.accumulated_token_usage.context_window == 8000
) # max of 8000 and 4000

def test_get_metrics_for_service(conversation_stats):
"""Test that metrics for a specific service are retrieved correctly."""
# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
conversation_stats.service_to_metrics[service_id] = metrics

# Get metrics for the service
retrieved_metrics = conversation_stats.get_metrics_for_service(service_id)

# Verify metrics
assert retrieved_metrics.accumulated_cost == 0.05
assert retrieved_metrics is metrics # Should be the same object

# Test getting metrics for non-existent service
# Use a specific exception message pattern instead of a blind Exception
with pytest.raises(Exception, match='LLM service does not exist'):
conversation_stats.get_metrics_for_service('non-existent-service')

def test_register_llm_with_new_service(conversation_stats):
"""Test registering a new LLM service."""
# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id='new-service', config=llm_config)

# Create a registry event
service_id = 'new-service'
event = RegistryEvent(llm=llm, service_id=service_id)

# Register the LLM
conversation_stats.register_llm(event)

# Verify the service was registered
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics

def test_register_llm_with_restored_metrics(conversation_stats):
"""Test registering an LLM service with restored metrics."""
# Create restored metrics
service_id = 'restored-service'
restored_metrics = Metrics(model_name='gpt-4')
restored_metrics.add_cost(0.1)
conversation_stats.restored_metrics = {service_id: restored_metrics}

# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id=service_id, config=llm_config)

# Create a registry event
event = RegistryEvent(llm=llm, service_id=service_id)

# Register the LLM
conversation_stats.register_llm(event)

# Verify the service was registered with restored metrics
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics
assert llm.metrics.accumulated_cost == 0.1 # Restored cost

# Verify the specific service was removed from restored_metrics
assert service_id not in conversation_stats.restored_metrics
assert hasattr(
conversation_stats, 'restored_metrics'
) # The dict should still exist

def test_llm_registry_notifications(connected_registry_and_stats):
"""Test that LLM registry notifications update conversation stats."""
mock_llm_registry, conversation_stats = connected_registry_and_stats

# Create a new LLM through the registry
service_id = 'test-service'
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Get LLM from registry (this should trigger the notification)
llm = mock_llm_registry.get_llm(service_id, llm_config)

# Verify the service was registered in conversation stats
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics

# Add some metrics to the LLM
llm.metrics.add_cost(0.05)
llm.metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)

# Verify the metrics are reflected in conversation stats
assert conversation_stats.service_to_metrics[service_id].accumulated_cost == 0.05
assert (
conversation_stats.service_to_metrics[
service_id
].accumulated_token_usage.prompt_tokens
== 100
)
assert (
conversation_stats.service_to_metrics[
service_id
].accumulated_token_usage.completion_tokens
== 50
)

# Get combined metrics and verify
combined = conversation_stats.get_combined_metrics()
assert combined.accumulated_cost == 0.05
assert combined.accumulated_token_usage.prompt_tokens == 100
assert combined.accumulated_token_usage.completion_tokens == 50

def test_multiple_llm_services(connected_registry_and_stats):
"""Test tracking metrics for multiple LLM services."""
mock_llm_registry, conversation_stats = connected_registry_and_stats

# Create multiple LLMs through the registry
service1 = 'service1'
service2 = 'service2'

llm_config1 = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

llm_config2 = LLMConfig(
model='gpt-3.5-turbo',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Get LLMs from registry (this should trigger notifications)
llm1 = mock_llm_registry.get_llm(service1, llm_config1)
llm2 = mock_llm_registry.get_llm(service2, llm_config2)

# Add different metrics to each LLM
llm1.metrics.add_cost(0.05)
llm1.metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)

llm2.metrics.add_cost(0.02)
llm2.metrics.add_token_usage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4000,
response_id='resp2',
)

# Verify services were registered in conversation stats
assert service1 in conversation_stats.service_to_metrics
assert service2 in conversation_stats.service_to_metrics

# Verify individual metrics
assert conversation_stats.service_to_metrics[service1].accumulated_cost == 0.05
assert conversation_stats.service_to_metrics[service2].accumulated_cost == 0.02

# Get combined metrics and verify
combined = conversation_stats.get_combined_metrics()
assert combined.accumulated_cost == 0.07 # 0.05 + 0.02
assert combined.accumulated_token_usage.prompt_tokens == 300 # 100 + 200
assert combined.accumulated_token_usage.completion_tokens == 150 # 50 + 100
assert (
combined.accumulated_token_usage.context_window == 8000
) # max of 8000 and 4000

def test_register_llm_with_multiple_restored_services_bug(conversation_stats):
"""Test that reproduces the bug where del self.restored_metrics deletes entire dict instead of specific service."""
# Create restored metrics for multiple services
service_id_1 = 'service-1'
service_id_2 = 'service-2'

restored_metrics_1 = Metrics(model_name='gpt-4')
restored_metrics_1.add_cost(0.1)

restored_metrics_2 = Metrics(model_name='gpt-3.5')
restored_metrics_2.add_cost(0.05)

# Set up restored metrics for both services
conversation_stats.restored_metrics = {
service_id_1: restored_metrics_1,
service_id_2: restored_metrics_2,
}

# Create LLM configs
llm_config_1 = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

llm_config_2 = LLMConfig(
model='gpt-3.5-turbo',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
# Register first LLM
llm_1 = LLM(service_id=service_id_1, config=llm_config_1)
event_1 = RegistryEvent(llm=llm_1, service_id=service_id_1)
conversation_stats.register_llm(event_1)

# Verify first service was registered with restored metrics
assert service_id_1 in conversation_stats.service_to_metrics
assert llm_1.metrics.accumulated_cost == 0.1

# After registering first service, restored_metrics should still contain service_id_2
assert service_id_2 in conversation_stats.restored_metrics

# Register second LLM - this should also work with restored metrics
llm_2 = LLM(service_id=service_id_2, config=llm_config_2)
event_2 = RegistryEvent(llm=llm_2, service_id=service_id_2)
conversation_stats.register_llm(event_2)

# Verify second service was registered with restored metrics
assert service_id_2 in conversation_stats.service_to_metrics
assert llm_2.metrics.accumulated_cost == 0.05

# After both services are registered, restored_metrics should be empty
assert len(conversation_stats.restored_metrics) == 0

def test_save_and_restore_workflow(mock_file_store):
"""Test the full workflow of saving and restoring metrics."""
# Create initial conversation stats
conversation_id = 'test-conversation-id'
user_id = 'test-user-id'

stats1 = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)

# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
stats1.service_to_metrics[service_id] = metrics

# Save metrics
stats1.save_metrics()

# Create a new conversation stats instance that should restore the metrics
stats2 = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)

# Verify metrics were restored
assert service_id in stats2.restored_metrics
assert stats2.restored_metrics[service_id].accumulated_cost == 0.05
assert (
stats2.restored_metrics[service_id].accumulated_token_usage.prompt_tokens == 100
)
assert (
stats2.restored_metrics[service_id].accumulated_token_usage.completion_tokens
== 50
)

# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)

# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id=service_id, config=llm_config)

# Create a registry event
event = RegistryEvent(llm=llm, service_id=service_id)

# Register the LLM to trigger restoration
stats2.register_llm(event)

# Verify metrics were applied to the LLM
assert llm.metrics.accumulated_cost == 0.05
assert llm.metrics.accumulated_token_usage.prompt_tokens == 100
assert llm.metrics.accumulated_token_usage.completion_tokens == 50
@@ -1,6 +1,6 @@
"""Tests for the conversation summary generator."""

from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock

import pytest

@@ -11,55 +11,51 @@ from openhands.utils.conversation_summary import generate_conversation_title
@pytest.mark.asyncio
async def test_generate_conversation_title_empty_message():
"""Test that an empty message returns None."""
result = await generate_conversation_title('', MagicMock())
mock_llm_registry = MagicMock()
mock_llm_config = LLMConfig(model='test-model')

result = await generate_conversation_title('', mock_llm_config, mock_llm_registry)
assert result is None

result = await generate_conversation_title(' ', MagicMock())
result = await generate_conversation_title(
' ', mock_llm_config, mock_llm_registry
)
assert result is None

@pytest.mark.asyncio
async def test_generate_conversation_title_success():
"""Test successful title generation."""
# Create a proper mock response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Generated Title'
# Create a mock LLM registry that returns a title
mock_llm_registry = MagicMock()
mock_llm_registry.request_extraneous_completion.return_value = 'Generated Title'

# Create a mock LLM instance with a synchronous completion method
mock_llm = MagicMock()
mock_llm.completion = MagicMock(return_value=mock_response)
mock_llm_config = LLMConfig(model='test-model')

# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model')
)
result = await generate_conversation_title(
'Can you help me with Python?', mock_llm_config, mock_llm_registry
)

assert result == 'Generated Title'
# Verify the mock was called with the expected arguments
mock_llm.completion.assert_called_once()
mock_llm_registry.request_extraneous_completion.assert_called_once()

@pytest.mark.asyncio
async def test_generate_conversation_title_long_title():
"""Test that long titles are truncated."""
# Create a proper mock response with a long title
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[
0
].message.content = 'This is a very long title that should be truncated because it exceeds the maximum length'
# Create a mock LLM registry that returns a long title
mock_llm_registry = MagicMock()
mock_llm_registry.request_extraneous_completion.return_value = 'This is a very long title that should be truncated because it exceeds the maximum length'

# Create a mock LLM instance with a synchronous completion method
mock_llm = MagicMock()
mock_llm.completion = MagicMock(return_value=mock_response)
mock_llm_config = LLMConfig(model='test-model')

# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model'), max_length=30
)
result = await generate_conversation_title(
'Can you help me with Python?',
mock_llm_config,
mock_llm_registry,
max_length=30,
)

# Verify the title is truncated correctly
assert len(result) <= 30
@@ -69,15 +65,17 @@ async def test_generate_conversation_title_long_title():
@pytest.mark.asyncio
async def test_generate_conversation_title_exception():
"""Test that exceptions are handled gracefully."""
# Create a mock LLM instance with a synchronous completion method that raises an exception
mock_llm = MagicMock()
mock_llm.completion = MagicMock(side_effect=Exception('Test error'))
# Create a mock LLM registry that raises an exception
mock_llm_registry = MagicMock()
mock_llm_registry.request_extraneous_completion.side_effect = Exception(
'Test error'
)

# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model')
)
mock_llm_config = LLMConfig(model='test-model')

result = await generate_conversation_title(
'Can you help me with Python?', mock_llm_config, mock_llm_registry
)

# Verify that None is returned when an exception occurs
assert result is None

@@ -4,6 +4,7 @@ import pytest

from openhands.core.config import OpenHandsConfig
from openhands.events import EventStream
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime

@@ -40,12 +41,17 @@ def event_stream():
return MagicMock(spec=EventStream)

@pytest.fixture
def llm_registry():
return MagicMock(spec=LLMRegistry)

@patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
def test_container_stopped_when_keep_runtime_alive_false(
mock_stop_containers, mock_docker_client, config, event_stream
mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
):
# Arrange
runtime = DockerRuntime(config, event_stream, sid='test-sid')
runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
runtime.container = mock_docker_client.containers.get.return_value

# Act
@@ -57,11 +63,11 @@ def test_container_stopped_when_keep_runtime_alive_false(

@patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
def test_container_not_stopped_when_keep_runtime_alive_true(
mock_stop_containers, mock_docker_client, config, event_stream
mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
):
# Arrange
config.sandbox.keep_runtime_alive = True
runtime = DockerRuntime(config, event_stream, sid='test-sid')
runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
runtime.container = mock_docker_client.containers.get.return_value

# Act

178
tests/unit/test_llm_registry.py
Normal file
@@ -0,0 +1,178 @@
from __future__ import annotations

import unittest
from unittest.mock import MagicMock, patch

from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent

class TestLLMRegistry(unittest.TestCase):
def setUp(self):
"""Set up test environment before each test."""
# Create a basic LLM config for testing
self.llm_config = LLMConfig(model='test-model')

# Create a basic OpenHands config for testing
self.config = OpenHandsConfig(
llms={'llm': self.llm_config}, default_agent='CodeActAgent'
)

# Create a registry for testing
self.registry = LLMRegistry(config=self.config)

def test_get_llm_creates_new_llm(self):
"""Test that get_llm creates a new LLM when service doesn't exist."""
service_id = 'test-service'

# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_llm.config = self.llm_config
mock_create.return_value = mock_llm

# Get LLM for the first time
llm = self.registry.get_llm(service_id, self.llm_config)

# Verify LLM was created and stored
self.assertEqual(llm, mock_llm)
mock_create.assert_called_once_with(
config=self.llm_config, service_id=service_id
)

def test_get_llm_returns_existing_llm(self):
"""Test that get_llm returns existing LLM when service already exists."""
service_id = 'test-service'

# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_llm.config = self.llm_config
mock_create.return_value = mock_llm

# Get LLM for the first time
llm1 = self.registry.get_llm(service_id, self.llm_config)

# Manually add to registry to simulate existing LLM
self.registry.service_to_llm[service_id] = mock_llm

# Get LLM for the second time - should return the same instance
llm2 = self.registry.get_llm(service_id, self.llm_config)

# Verify same LLM instance is returned
self.assertEqual(llm1, llm2)
self.assertEqual(llm1, mock_llm)

# Verify _create_new_llm was only called once
mock_create.assert_called_once()

def test_get_llm_with_different_config_raises_error(self):
"""Test that requesting same service ID with different config raises an error."""
service_id = 'test-service'
different_config = LLMConfig(model='different-model')

# Manually add an LLM to the registry to simulate existing service
mock_llm = MagicMock()
mock_llm.config = self.llm_config
self.registry.service_to_llm[service_id] = mock_llm

# Attempt to get LLM with different config should raise ValueError
with self.assertRaises(ValueError) as context:
self.registry.get_llm(service_id, different_config)

self.assertIn('Requesting same service ID', str(context.exception))
self.assertIn('with different config', str(context.exception))

def test_get_llm_without_config_raises_error(self):
"""Test that requesting new LLM without config raises an error."""
service_id = 'test-service'

# Attempt to get LLM without providing config should raise ValueError
with self.assertRaises(ValueError) as context:
self.registry.get_llm(service_id, None)

self.assertIn(
'Requesting new LLM without specifying LLM config', str(context.exception)
)

def test_request_extraneous_completion(self):
"""Test that requesting an extraneous completion creates a new LLM if needed."""
service_id = 'extraneous-service'
messages = [{'role': 'user', 'content': 'Hello, world!'}]

# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = ' Hello from the LLM! '
mock_llm.completion.return_value = mock_response
mock_create.return_value = mock_llm

# Mock the side effect to add the LLM to the registry
def side_effect(*args, **kwargs):
self.registry.service_to_llm[service_id] = mock_llm
return mock_llm

mock_create.side_effect = side_effect

# Request a completion
response = self.registry.request_extraneous_completion(
service_id=service_id,
llm_config=self.llm_config,
messages=messages,
)

# Verify the response (should be stripped)
self.assertEqual(response, 'Hello from the LLM!')

# Verify that _create_new_llm was called with correct parameters
mock_create.assert_called_once_with(
config=self.llm_config, service_id=service_id, with_listener=False
)

# Verify completion was called with correct messages
mock_llm.completion.assert_called_once_with(messages=messages)

def test_get_active_llm(self):
"""Test that get_active_llm returns the active agent LLM."""
active_llm = self.registry.get_active_llm()
self.assertEqual(active_llm, self.registry.active_agent_llm)

def test_subscribe_and_notify(self):
"""Test the subscription and notification system."""
events_received = []

def callback(event: RegistryEvent):
events_received.append(event)

# Subscribe to events
self.registry.subscribe(callback)

# Should receive notification for the active agent LLM
self.assertEqual(len(events_received), 1)
self.assertEqual(events_received[0].llm, self.registry.active_agent_llm)
self.assertEqual(
events_received[0].service_id, self.registry.active_agent_llm.service_id
)

# Test that the subscriber is set correctly
self.assertIsNotNone(self.registry.subscriber)

# Test notify method directly with a mock event
with patch.object(self.registry, 'subscriber') as mock_subscriber:
mock_event = MagicMock()
self.registry.notify(mock_event)
mock_subscriber.assert_called_once_with(mock_event)

def test_registry_has_unique_id(self):
"""Test that each registry instance has a unique ID."""
registry2 = LLMRegistry(config=self.config)
self.assertNotEqual(self.registry.registry_id, registry2.registry_id)
self.assertTrue(len(self.registry.registry_id) > 0)
self.assertTrue(len(registry2.registry_id) > 0)

if __name__ == '__main__':
unittest.main()
@@ -12,6 +12,8 @@ from openhands.core.config.mcp_config import (
MCPSSEServerConfig,
MCPStdioServerConfig,
)
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.session.session import Session
from openhands.storage.memory import InMemoryFileStore
@@ -428,6 +430,8 @@ async def test_session_preserves_env_mcp_config(monkeypatch):
file_store=InMemoryFileStore({}),
config=config,
sio=AsyncMock(),
llm_registry=LLMRegistry(config=OpenHandsConfig()),
convo_stats=ConversationStats(None, 'test-sid', None),
)

# Create empty settings

@@ -8,7 +8,8 @@ import pytest
from mcp import McpError

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController, AgentState
from openhands.controller.agent_controller import AgentController
from openhands.core.schema import AgentState
from openhands.events.action.mcp import MCPAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import EventSource
@@ -17,6 +18,8 @@ from openhands.events.stream import EventStream
from openhands.mcp.client import MCPClient
from openhands.mcp.tool import MCPClientTool
from openhands.mcp.utils import call_tool_mcp
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

class MockConfig:
@@ -34,6 +37,11 @@ class MockLLM:
self.config = MockConfig()

@pytest.fixture
def convo_stats():
return ConversationStats(None, 'convo-id', None)

class MockAgent(Agent):
"""Mock agent for testing."""

@@ -53,7 +61,7 @@ class MockAgent(Agent):

@pytest.mark.asyncio
async def test_mcp_tool_timeout_error_handling():
async def test_mcp_tool_timeout_error_handling(convo_stats):
"""Test that verifies MCP tool timeout errors are properly handled and returned as observations."""
# Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient)
@@ -80,7 +88,7 @@ async def test_mcp_tool_timeout_error_handling():
mock_client.tool_map = {'test_tool': mock_tool}

# Create a mock file store
mock_file_store = mock.MagicMock()
mock_file_store = InMemoryFileStore({})

# Create a mock event stream
event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@@ -90,13 +98,12 @@ async def test_mcp_tool_timeout_error_handling():

# Create a mock agent controller
controller = AgentController(
sid='test-session',
file_store=mock_file_store,
user_id='test-user',
agent=agent,
event_stream=event_stream,
convo_stats=convo_stats,
iteration_delta=10,
budget_per_task_delta=None,
sid='test-session',
)

# Set up the agent state
@@ -143,7 +150,7 @@ async def test_mcp_tool_timeout_error_handling():

@pytest.mark.asyncio
async def test_mcp_tool_timeout_agent_continuation():
async def test_mcp_tool_timeout_agent_continuation(convo_stats):
"""Test that verifies the agent can continue processing after an MCP tool timeout."""
# Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient)
@@ -170,7 +177,7 @@ async def test_mcp_tool_timeout_agent_continuation():
mock_client.tool_map = {'test_tool': mock_tool}

# Create a mock file store
mock_file_store = mock.MagicMock()
mock_file_store = InMemoryFileStore({})

# Create a mock event stream
event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@@ -180,13 +187,12 @@ async def test_mcp_tool_timeout_agent_continuation():

# Create a mock agent controller
controller = AgentController(
sid='test-session',
file_store=mock_file_store,
user_id='test-user',
agent=agent,
event_stream=event_stream,
convo_stats=convo_stats,
iteration_delta=10,
budget_per_task_delta=None,
sid='test-session',
)

# Set up the agent state

@@ -21,11 +21,13 @@ from openhands.events.observation.agent import (
from openhands.events.serialization.observation import observation_from_dict
from openhands.events.stream import EventStream
from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
from openhands.utils.prompt import (
@@ -42,6 +44,12 @@ def file_store():
return InMemoryFileStore({})

@pytest.fixture
def mock_llm_registry(file_store):
"""Create a mock LLMRegistry for testing."""
return MagicMock(spec=LLMRegistry)

@pytest.fixture
def event_stream(file_store):
"""Create a test event stream."""
@@ -90,24 +98,29 @@ def mock_agent():

@pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent):
async def test_memory_on_event_exception_handling(
memory, event_stream, mock_agent, mock_llm_registry
):
"""Test that exceptions in Memory.on_event are properly handled via status callback."""
# Create a mock runtime
runtime = MagicMock(spec=ActionExecutionClient)
runtime.event_stream = event_stream

# Mock Memory method to raise an exception
with patch.object(
memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
with (
patch.object(
memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
),
patch('openhands.core.main.create_agent', return_value=mock_agent),
):
state = await run_controller(
config=OpenHandsConfig(),
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
llm_registry=mock_llm_registry,
)

# Verify that the controller's last error was set
@@ -118,7 +131,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream, mock_age

@pytest.mark.asyncio
async def test_memory_on_workspace_context_recall_exception_handling(
memory, event_stream, mock_agent
memory, event_stream, mock_agent, mock_llm_registry
):
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
# Create a mock runtime
@@ -126,19 +139,22 @@ async def test_memory_on_workspace_context_recall_exception_handling(
runtime.event_stream = event_stream

# Mock Memory._on_workspace_context_recall to raise an exception
with patch.object(
memory,
'_find_microagent_knowledge',
side_effect=Exception('Test error from _find_microagent_knowledge'),
with (
patch.object(
memory,
'_find_microagent_knowledge',
side_effect=Exception('Test error from _find_microagent_knowledge'),
),
patch('openhands.core.main.create_agent', return_value=mock_agent),
):
state = await run_controller(
config=OpenHandsConfig(),
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
memory=memory,
llm_registry=mock_llm_registry,
)

# Verify that the controller's last error was set
@@ -593,12 +609,14 @@ REPOSITORY INSTRUCTIONS: This is the second test repository.

@pytest.mark.asyncio
async def test_conversation_instructions_plumbed_to_memory(
mock_agent, event_stream, file_store
mock_agent, event_stream, file_store, mock_llm_registry
):
# Setup
session = AgentSession(
sid='test-session',
file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=ConversationStats(file_store, 'test-session', None),
)

# Create a mock runtime and set it up

@@ -3,26 +3,30 @@ from litellm import ModelResponse

from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.events.action import MessageAction
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry

@pytest.fixture
def mock_llm():
llm = LLM(
LLMConfig(
model='claude-3-5-sonnet-20241022',
api_key='fake',
caching_prompt=True,
)
def llm_config():
return LLMConfig(
model='claude-3-5-sonnet-20241022',
api_key='fake',
caching_prompt=True,
)
return llm

@pytest.fixture
def codeact_agent(mock_llm):
def llm_registry():
registry = LLMRegistry(config=OpenHandsConfig())
return registry

@pytest.fixture
def codeact_agent(llm_registry):
config = AgentConfig()
agent = CodeActAgent(mock_llm, config)
agent = CodeActAgent(config, llm_registry)
return agent

@@ -12,14 +12,26 @@ from openhands.events.observation import NullObservation, Observation
from openhands.events.stream import EventStream
from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType
from openhands.integrations.service_types import AuthenticationError, Repository
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store

class TestRuntime(Runtime):
class MockRuntime(Runtime):
"""A concrete implementation of Runtime for testing"""

def __init__(self, *args, **kwargs):
# Ensure llm_registry is provided if not already in kwargs
if 'llm_registry' not in kwargs and len(args) < 3:
# Create a mock LLMRegistry if not provided
config = (
kwargs.get('config')
if 'config' in kwargs
else args[0]
if args
else OpenHandsConfig()
)
kwargs['llm_registry'] = LLMRegistry(config=config)
super().__init__(*args, **kwargs)
self.run_action_calls = []
self._execute_shell_fn_git_handler = MagicMock(
@@ -89,9 +101,11 @@ def runtime(temp_dir):
)
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
llm_registry = LLMRegistry(config=config)
runtime = MockRuntime(
config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid='test',
user_id='test_user',
git_provider_tokens=git_provider_tokens,
@@ -119,7 +133,7 @@ async def test_export_latest_git_provider_tokens_no_user_id(temp_dir):
config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(config=config, event_stream=event_stream, sid='test')
runtime = MockRuntime(config=config, event_stream=event_stream, sid='test')

# Create a command that would normally trigger token export
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
@@ -137,7 +151,7 @@ async def test_export_latest_git_provider_tokens_no_token_ref(temp_dir):
config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)

@@ -177,7 +191,7 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
)
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config,
event_stream=event_stream,
sid='test',
@@ -225,7 +239,7 @@ async def test_clone_or_init_repo_no_repo_init_git_in_empty_workspace(temp_dir):
config.init_git_in_empty_workspace = True
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None
)

@@ -249,7 +263,7 @@ async def test_clone_or_init_repo_no_repo_no_user_id_with_workspace_base(temp_di
config.workspace_base = '/some/path' # Set workspace_base
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None
)

@@ -267,7 +281,7 @@ async def test_clone_or_init_repo_auth_error(temp_dir):
config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)
runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)

@@ -298,7 +312,7 @@ async def test_clone_or_init_repo_github_with_token(temp_dir, monkeypatch):
{ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))}
)

runtime = TestRuntime(
runtime = MockRuntime(
config=config,
event_stream=event_stream,
sid='test',
@@ -336,7 +350,7 @@ async def test_clone_or_init_repo_github_no_token(temp_dir, monkeypatch):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)

runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)

@@ -371,7 +385,7 @@ async def test_clone_or_init_repo_gitlab_with_token(temp_dir, monkeypatch):
{ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))}
)

runtime = TestRuntime(
runtime = MockRuntime(
config=config,
event_stream=event_stream,
sid='test',
@@ -410,7 +424,7 @@ async def test_clone_or_init_repo_with_branch(temp_dir, monkeypatch):
file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store)

runtime = TestRuntime(
runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user'
)

@@ -9,10 +9,12 @@ import pytest
from openhands.core.config import OpenHandsConfig, SandboxConfig
from openhands.events import EventStream
from openhands.integrations.service_types import ProviderType, Repository
from openhands.llm.llm_registry import LLMRegistry
from openhands.microagent.microagent import (
RepoMicroagent,
)
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store


class MockRuntime(Runtime):
@@ -24,12 +26,21 @@ class MockRuntime(Runtime):
config.workspace_mount_path_in_sandbox = str(workspace_root)
config.sandbox = SandboxConfig()

# Create a mock event stream
# Create a mock event stream and file store
file_store = get_file_store('local', str(workspace_root))
event_stream = MagicMock(spec=EventStream)
event_stream.file_store = file_store

# Create a mock LLM registry
llm_registry = LLMRegistry(config)

# Initialize the parent class properly
super().__init__(
config=config, event_stream=event_stream, sid='test', git_provider_tokens={}
config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid='test',
git_provider_tokens={},
)

self._workspace_root = workspace_root

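The mock event stream pattern used in this hunk, isolated into a small sketch (the /tmp/workspace path is illustrative only):

    from unittest.mock import MagicMock

    from openhands.events import EventStream
    from openhands.storage import get_file_store

    file_store = get_file_store('local', '/tmp/workspace')  # illustrative path
    event_stream = MagicMock(spec=EventStream)
    # With spec=EventStream the mock only exposes what the class defines, so the
    # file_store attribute is attached by hand for code that reads event_stream.file_store.
    event_stream.file_store = file_store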
@@ -595,7 +595,7 @@ async def test_check_usertask(
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
analyzer.check_browsing_alignment = True
data = [
(MessageAction(usertask), EventSource.USER),
@@ -657,7 +657,7 @@ async def test_check_fillaction(
analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_harmful}}]}
mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config)
analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
analyzer.check_browsing_alignment = True
data = [
(BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),

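As these two hunks show, constructing an LLM directly now takes a service_id alongside the config. A minimal sketch (the model, api_key, and 'guardrail' id are placeholder values, not taken from this commit):

    from openhands.core.config.llm_config import LLMConfig
    from openhands.llm.llm import LLM

    llm_config = LLMConfig(model='gpt-4o', api_key='test_key')  # placeholder values
    # The service id labels which service this LLM instance belongs to.
    guardrail_llm = LLM(config=llm_config, service_id='guardrail')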
@@ -7,7 +7,9 @@ from litellm.exceptions import (

from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.session import Session
from openhands.storage.memory import InMemoryFileStore

@@ -33,10 +35,28 @@ def default_llm_config():
)


@pytest.fixture
def llm_registry():
config = OpenHandsConfig()
return LLMRegistry(config=config)


@pytest.fixture
def conversation_stats():
file_store = InMemoryFileStore({})
return ConversationStats(
file_store=file_store, conversation_id='test-conversation', user_id='test-user'
)

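A condensed sketch of how these two fixtures are typically consumed when building a Session in a test (mock_sio is assumed to be another fixture providing a socket.io double; it is not defined in this commit, and Session, OpenHandsConfig, and InMemoryFileStore are imported as in the hunk above):

    import pytest

    @pytest.mark.asyncio
    async def test_session_wiring(llm_registry, conversation_stats, mock_sio):
        # Hypothetical test showing only how the fixtures are wired into Session.
        session = Session(
            sid='test-sid',
            file_store=InMemoryFileStore({}),
            config=OpenHandsConfig(),
            llm_registry=llm_registry,
            convo_stats=conversation_stats,
            sio=mock_sio,
            user_id='test-user',
        )
        assert session is not None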
@pytest.mark.asyncio
@patch('openhands.llm.llm.litellm_completion')
async def test_notify_on_llm_retry(
mock_litellm_completion, mock_sio, default_llm_config
mock_litellm_completion,
mock_sio,
default_llm_config,
llm_registry,
conversation_stats,
):
config = OpenHandsConfig()
config.set_llm_config(default_llm_config)
@@ -44,6 +64,8 @@ async def test_notify_on_llm_retry(
sid='..sid..',
file_store=InMemoryFileStore({}),
config=config,
llm_registry=llm_registry,
convo_stats=conversation_stats,
sio=mock_sio,
user_id='..uid..',
)
@@ -56,12 +78,20 @@ async def test_notify_on_llm_retry(
),
{'choices': [{'message': {'content': 'Retry successful'}}]},
]
llm = session._create_llm('..cls..')

llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
)
# Set the retry listener on the registry
llm_registry.retry_listner = session._notify_on_llm_retry

# Create an LLM through the registry
llm = llm_registry.get_llm(
service_id='test_service',
config=default_llm_config,
)

llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
)

assert mock_litellm_completion.call_count == 2
session.queue_status_message.assert_called_once_with(
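The retry-notification flow exercised above, reduced to a minimal sketch (the on_retry callback and its signature stand in for Session._notify_on_llm_retry and are assumptions, not taken from this commit; the attribute spelling retry_listner follows the test above):

    from openhands.core.config import OpenHandsConfig
    from openhands.llm.llm_registry import LLMRegistry

    config = OpenHandsConfig()
    registry = LLMRegistry(config=config)

    def on_retry(attempt: int, max_retries: int) -> None:
        # Illustrative listener; the real test plugs in Session._notify_on_llm_retry,
        # and this signature is an assumption.
        print(f'LLM retry {attempt}/{max_retries}')

    registry.retry_listner = on_retry  # attribute name as used in the test above
    llm = registry.get_llm(service_id='test_service', config=config.get_llm_config())
    # Subsequent llm.completion(...) calls report retries through on_retry.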