mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
13 Commits
increase-a
...
openhands/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d97ae6c5d | ||
|
|
eeb81ecc49 | ||
|
|
eaea8b3ce1 | ||
|
|
72555e0f1c | ||
|
|
fd13c91387 | ||
|
|
6139e39449 | ||
|
|
f76ac242f0 | ||
|
|
1f9350320f | ||
|
|
1a3460ba06 | ||
|
|
8f361b3698 | ||
|
|
4d0f2e7a6d | ||
|
|
2c6d1e97e8 | ||
|
|
180557265f |
@@ -252,7 +252,12 @@ def get_api_key_from_header(request: Request):
|
||||
# This is a temp hack
|
||||
# Streamable HTTP MCP Client works via redirect requests, but drops the Authorization header for reason
|
||||
# We include `X-Session-API-Key` header by default due to nested runtimes, so it used as a drop in replacement here
|
||||
return request.headers.get('X-Session-API-Key')
|
||||
session_api_key = request.headers.get('X-Session-API-Key')
|
||||
if session_api_key:
|
||||
return session_api_key
|
||||
|
||||
# Fallback to X-Access-Token header as an additional option
|
||||
return request.headers.get('X-Access-Token')
|
||||
|
||||
|
||||
async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
|
||||
@@ -535,3 +535,115 @@ def test_get_api_key_from_header_with_invalid_authorization_format():
|
||||
|
||||
# Assert that None was returned
|
||||
assert api_key is None
|
||||
|
||||
|
||||
def test_get_api_key_from_header_with_x_access_token():
|
||||
"""Test that get_api_key_from_header extracts API key from X-Access-Token header."""
|
||||
# Create a mock request with X-Access-Token header
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {'X-Access-Token': 'access_token_key'}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key was correctly extracted
|
||||
assert api_key == 'access_token_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_priority_authorization_over_x_access_token():
|
||||
"""Test that Authorization header takes priority over X-Access-Token header."""
|
||||
# Create a mock request with both headers
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'Bearer auth_api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from Authorization header was used
|
||||
assert api_key == 'auth_api_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_priority_x_session_over_x_access_token():
|
||||
"""Test that X-Session-API-Key header takes priority over X-Access-Token header."""
|
||||
# Create a mock request with both headers
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'X-Session-API-Key': 'session_api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from X-Session-API-Key header was used
|
||||
assert api_key == 'session_api_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_all_three_headers():
|
||||
"""Test header priority when all three headers are present."""
|
||||
# Create a mock request with all three headers
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'Bearer auth_api_key',
|
||||
'X-Session-API-Key': 'session_api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from Authorization header was used (highest priority)
|
||||
assert api_key == 'auth_api_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_invalid_authorization_fallback_to_x_access_token():
|
||||
"""Test that invalid Authorization header falls back to X-Access-Token."""
|
||||
# Create a mock request with invalid Authorization header and X-Access-Token
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'InvalidFormat api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from X-Access-Token header was used
|
||||
assert api_key == 'access_token_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_empty_headers():
|
||||
"""Test that empty header values are handled correctly."""
|
||||
# Create a mock request with empty header values
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': '',
|
||||
'X-Session-API-Key': '',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from X-Access-Token header was used
|
||||
assert api_key == 'access_token_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_bearer_with_empty_token():
|
||||
"""Test that Bearer header with empty token falls back to other headers."""
|
||||
# Create a mock request with Bearer header with empty token
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'Bearer ',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that empty string from Bearer is returned (current behavior)
|
||||
# This tests the current implementation behavior
|
||||
assert api_key == ''
|
||||
|
||||
@@ -3,15 +3,19 @@ import { Provider } from "#/types/settings";
|
||||
import { V1SandboxStatus } from "../sandbox-service/sandbox-service.types";
|
||||
|
||||
// V1 API Types for requests
|
||||
// Note: This represents the serialized API format, not the internal TextContent/ImageContent types
|
||||
export interface V1MessageContent {
|
||||
type: "text" | "image_url";
|
||||
text?: string;
|
||||
image_url?: {
|
||||
url: string;
|
||||
};
|
||||
// These types match the SDK's TextContent and ImageContent formats
|
||||
export interface V1TextContent {
|
||||
type: "text";
|
||||
text: string;
|
||||
}
|
||||
|
||||
export interface V1ImageContent {
|
||||
type: "image";
|
||||
image_urls: string[];
|
||||
}
|
||||
|
||||
export type V1MessageContent = V1TextContent | V1ImageContent;
|
||||
|
||||
type V1Role = "user" | "system" | "assistant" | "tool";
|
||||
|
||||
export interface V1SendMessageRequest {
|
||||
|
||||
@@ -41,13 +41,11 @@ export function useSendMessage() {
|
||||
},
|
||||
];
|
||||
|
||||
// Add images if present
|
||||
// Add images if present - using SDK's ImageContent format
|
||||
if (args.image_urls && args.image_urls.length > 0) {
|
||||
args.image_urls.forEach((url) => {
|
||||
content.push({
|
||||
type: "image_url",
|
||||
image_url: { url },
|
||||
});
|
||||
content.push({
|
||||
type: "image",
|
||||
image_urls: args.image_urls,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -30,11 +30,12 @@ function BillingSettingsScreen() {
|
||||
}
|
||||
|
||||
displaySuccessToast(t(I18nKey.PAYMENT$SUCCESS));
|
||||
|
||||
setSearchParams({});
|
||||
} else if (checkoutStatus === "cancel") {
|
||||
displayErrorToast(t(I18nKey.PAYMENT$CANCELLED));
|
||||
setSearchParams({});
|
||||
}
|
||||
|
||||
setSearchParams({});
|
||||
}, [checkoutStatus, searchParams, setSearchParams, t, trackCreditsPurchased]);
|
||||
|
||||
return <PaymentForm />;
|
||||
|
||||
@@ -28,6 +28,7 @@ import { KeyStatusIcon } from "#/components/features/settings/key-status-icon";
|
||||
import { DEFAULT_SETTINGS } from "#/services/settings";
|
||||
import { getProviderId } from "#/utils/map-provider";
|
||||
import { DEFAULT_OPENHANDS_MODEL } from "#/utils/verified-models";
|
||||
import { USE_V1_CONVERSATION_API } from "#/utils/feature-flags";
|
||||
|
||||
interface OpenHandsApiKeyHelpProps {
|
||||
testId: string;
|
||||
@@ -118,6 +119,9 @@ function LlmSettingsScreen() {
|
||||
const isSaasMode = config?.APP_MODE === "saas";
|
||||
const shouldUseOpenHandsKey = isOpenHandsProvider && isSaasMode;
|
||||
|
||||
// Determine if we should hide the agent dropdown when V1 conversation API feature flag is enabled
|
||||
const isV1Enabled = USE_V1_CONVERSATION_API();
|
||||
|
||||
React.useEffect(() => {
|
||||
const determineWhetherToToggleAdvancedSettings = () => {
|
||||
if (resources && settings) {
|
||||
@@ -612,21 +616,23 @@ function LlmSettingsScreen() {
|
||||
href="https://tavily.com/"
|
||||
/>
|
||||
|
||||
<SettingsDropdownInput
|
||||
testId="agent-input"
|
||||
name="agent-input"
|
||||
label={t(I18nKey.SETTINGS$AGENT)}
|
||||
items={
|
||||
resources?.agents.map((agent) => ({
|
||||
key: agent,
|
||||
label: agent, // TODO: Add i18n support for agent names
|
||||
})) || []
|
||||
}
|
||||
defaultSelectedKey={settings.AGENT}
|
||||
isClearable={false}
|
||||
onInputChange={handleAgentIsDirty}
|
||||
wrapperClassName="w-full max-w-[680px]"
|
||||
/>
|
||||
{!isV1Enabled && (
|
||||
<SettingsDropdownInput
|
||||
testId="agent-input"
|
||||
name="agent-input"
|
||||
label={t(I18nKey.SETTINGS$AGENT)}
|
||||
items={
|
||||
resources?.agents.map((agent) => ({
|
||||
key: agent,
|
||||
label: agent, // TODO: Add i18n support for agent names
|
||||
})) || []
|
||||
}
|
||||
defaultSelectedKey={settings.AGENT}
|
||||
isClearable={false}
|
||||
onInputChange={handleAgentIsDirty}
|
||||
wrapperClassName="w-full max-w-[680px]"
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationSortOrder,
|
||||
)
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
|
||||
@@ -92,6 +93,19 @@ class AppConversationInfoService(ABC):
|
||||
Return the stored info
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def process_stats_event(
|
||||
self,
|
||||
event: ConversationStateUpdateEvent,
|
||||
conversation_id: UUID,
|
||||
) -> None:
|
||||
"""Process a stats event and update conversation statistics.
|
||||
|
||||
Args:
|
||||
event: The ConversationStateUpdateEvent with key='stats'
|
||||
conversation_id: The ID of the conversation to update
|
||||
"""
|
||||
|
||||
|
||||
class AppConversationInfoServiceInjector(
|
||||
DiscriminatedUnionMixin, Injector[AppConversationInfoService], ABC
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import AsyncGenerator
|
||||
import base62
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AgentType,
|
||||
AppConversationStartTask,
|
||||
AppConversationStartTaskStatus,
|
||||
)
|
||||
@@ -25,7 +26,9 @@ from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.sdk import Agent
|
||||
from openhands.sdk.context.agent_context import AgentContext
|
||||
from openhands.sdk.context.condenser import LLMSummarizingCondenser
|
||||
from openhands.sdk.context.skills import load_user_skills
|
||||
from openhands.sdk.llm import LLM
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -182,6 +185,43 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
workspace.working_dir,
|
||||
)
|
||||
|
||||
async def _configure_git_user_settings(
|
||||
self,
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
) -> None:
|
||||
"""Configure git global user settings from user preferences.
|
||||
|
||||
Reads git_user_name and git_user_email from user settings and
|
||||
configures them as git global settings in the workspace.
|
||||
|
||||
Args:
|
||||
workspace: The remote workspace to configure git settings in.
|
||||
"""
|
||||
try:
|
||||
user_info = await self.user_context.get_user_info()
|
||||
|
||||
if user_info.git_user_name:
|
||||
cmd = f'git config --global user.name "{user_info.git_user_name}"'
|
||||
result = await workspace.execute_command(cmd, workspace.working_dir)
|
||||
if result.exit_code:
|
||||
_logger.warning(f'Git config user.name failed: {result.stderr}')
|
||||
else:
|
||||
_logger.info(
|
||||
f'Git configured with user.name={user_info.git_user_name}'
|
||||
)
|
||||
|
||||
if user_info.git_user_email:
|
||||
cmd = f'git config --global user.email "{user_info.git_user_email}"'
|
||||
result = await workspace.execute_command(cmd, workspace.working_dir)
|
||||
if result.exit_code:
|
||||
_logger.warning(f'Git config user.email failed: {result.stderr}')
|
||||
else:
|
||||
_logger.info(
|
||||
f'Git configured with user.email={user_info.git_user_email}'
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to configure git user settings: {e}')
|
||||
|
||||
async def clone_or_init_git_repo(
|
||||
self,
|
||||
task: AppConversationStartTask,
|
||||
@@ -197,6 +237,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
if result.exit_code:
|
||||
_logger.warning(f'mkdir failed: {result.stderr}')
|
||||
|
||||
# Configure git user settings from user preferences
|
||||
await self._configure_git_user_settings(workspace)
|
||||
|
||||
if not request.selected_repository:
|
||||
if self.init_git_in_empty_workspace:
|
||||
_logger.debug('Initializing a new git repository in the workspace.')
|
||||
@@ -221,7 +264,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
|
||||
# Clone the repo - this is the slow part!
|
||||
clone_command = f'git clone {remote_repo_url} {dir_name}'
|
||||
result = await workspace.execute_command(clone_command, workspace.working_dir)
|
||||
result = await workspace.execute_command(
|
||||
clone_command, workspace.working_dir, 120
|
||||
)
|
||||
if result.exit_code:
|
||||
_logger.warning(f'Git clone failed: {result.stderr}')
|
||||
|
||||
@@ -233,7 +278,10 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
random_str = base62.encodebytes(os.urandom(16))
|
||||
openhands_workspace_branch = f'openhands-workspace-{random_str}'
|
||||
checkout_command = f'git checkout -b {openhands_workspace_branch}'
|
||||
await workspace.execute_command(checkout_command, workspace.working_dir)
|
||||
git_dir = Path(workspace.working_dir) / dir_name
|
||||
result = await workspace.execute_command(checkout_command, git_dir)
|
||||
if result.exit_code:
|
||||
_logger.warning(f'Git checkout failed: {result.stderr}')
|
||||
|
||||
async def maybe_run_setup_script(
|
||||
self,
|
||||
@@ -295,3 +343,39 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
return
|
||||
|
||||
_logger.info('Git pre-commit hook installed successfully')
|
||||
|
||||
def _create_condenser(
|
||||
self,
|
||||
llm: LLM,
|
||||
agent_type: AgentType,
|
||||
condenser_max_size: int | None,
|
||||
) -> LLMSummarizingCondenser:
|
||||
"""Create a condenser based on user settings and agent type.
|
||||
|
||||
Args:
|
||||
llm: The LLM instance to use for condensation
|
||||
agent_type: Type of agent (PLAN or DEFAULT)
|
||||
condenser_max_size: condenser_max_size setting
|
||||
|
||||
Returns:
|
||||
Configured LLMSummarizingCondenser instance
|
||||
"""
|
||||
# LLMSummarizingCondenser has defaults: max_size=120, keep_first=4
|
||||
condenser_kwargs = {
|
||||
'llm': llm.model_copy(
|
||||
update={
|
||||
'usage_id': (
|
||||
'condenser'
|
||||
if agent_type == AgentType.DEFAULT
|
||||
else 'planning_condenser'
|
||||
)
|
||||
}
|
||||
),
|
||||
}
|
||||
# Only override max_size if user has a custom value
|
||||
if condenser_max_size is not None:
|
||||
condenser_kwargs['max_size'] = condenser_max_size
|
||||
|
||||
condenser = LLMSummarizingCondenser(**condenser_kwargs)
|
||||
|
||||
return condenser
|
||||
|
||||
@@ -76,12 +76,10 @@ from openhands.sdk.security.confirmation_policy import AlwaysConfirm
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.tools.preset.default import (
|
||||
get_default_condenser,
|
||||
get_default_tools,
|
||||
)
|
||||
from openhands.tools.preset.planning import (
|
||||
format_plan_structure,
|
||||
get_planning_condenser,
|
||||
get_planning_tools,
|
||||
)
|
||||
|
||||
@@ -643,6 +641,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
agent_type: AgentType,
|
||||
system_message_suffix: str | None,
|
||||
mcp_config: dict,
|
||||
condenser_max_size: int | None,
|
||||
) -> Agent:
|
||||
"""Create an agent with appropriate tools and context based on agent type.
|
||||
|
||||
@@ -651,10 +650,14 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
agent_type: Type of agent to create (PLAN or DEFAULT)
|
||||
system_message_suffix: Optional suffix for system messages
|
||||
mcp_config: MCP configuration dictionary
|
||||
condenser_max_size: condenser_max_size setting
|
||||
|
||||
Returns:
|
||||
Configured Agent instance with context
|
||||
"""
|
||||
# Create condenser with user's settings
|
||||
condenser = self._create_condenser(llm, agent_type, condenser_max_size)
|
||||
|
||||
# Create agent based on type
|
||||
if agent_type == AgentType.PLAN:
|
||||
agent = Agent(
|
||||
@@ -662,9 +665,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
tools=get_planning_tools(),
|
||||
system_prompt_filename='system_prompt_planning.j2',
|
||||
system_prompt_kwargs={'plan_structure': format_plan_structure()},
|
||||
condenser=get_planning_condenser(
|
||||
llm=llm.model_copy(update={'usage_id': 'planning_condenser'})
|
||||
),
|
||||
condenser=condenser,
|
||||
security_analyzer=None,
|
||||
mcp_config=mcp_config,
|
||||
)
|
||||
@@ -673,9 +674,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
llm=llm,
|
||||
tools=get_default_tools(enable_browser=True),
|
||||
system_prompt_kwargs={'cli_mode': False},
|
||||
condenser=get_default_condenser(
|
||||
llm=llm.model_copy(update={'usage_id': 'condenser'})
|
||||
),
|
||||
condenser=condenser,
|
||||
mcp_config=mcp_config,
|
||||
)
|
||||
|
||||
@@ -777,7 +776,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
|
||||
# Create agent with context
|
||||
agent = self._create_agent_with_context(
|
||||
llm, agent_type, system_message_suffix, mcp_config
|
||||
llm, agent_type, system_message_suffix, mcp_config, user.condenser_max_size
|
||||
)
|
||||
|
||||
# Finalize and return the conversation request
|
||||
|
||||
@@ -45,6 +45,8 @@ from openhands.app_server.utils.sql_utils import (
|
||||
create_json_type_decorator,
|
||||
)
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk.conversation.conversation_stats import ConversationStats
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
@@ -354,6 +356,130 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
await self.db_session.commit()
|
||||
return info
|
||||
|
||||
async def update_conversation_statistics(
|
||||
self, conversation_id: UUID, stats: ConversationStats
|
||||
) -> None:
|
||||
"""Update conversation statistics from stats event data.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation to update
|
||||
stats: ConversationStats object containing usage_to_metrics data from stats event
|
||||
"""
|
||||
# Extract agent metrics from usage_to_metrics
|
||||
usage_to_metrics = stats.usage_to_metrics
|
||||
agent_metrics = usage_to_metrics.get('agent')
|
||||
|
||||
if not agent_metrics:
|
||||
logger.debug(
|
||||
'No agent metrics found in stats for conversation %s', conversation_id
|
||||
)
|
||||
return
|
||||
|
||||
# Query existing record using secure select (filters for V1 and user if available)
|
||||
query = await self._secure_select()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
stored = result.scalar_one_or_none()
|
||||
|
||||
if not stored:
|
||||
logger.debug(
|
||||
'Conversation %s not found or not accessible, skipping statistics update',
|
||||
conversation_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Extract accumulated_cost and max_budget_per_task from Metrics object
|
||||
accumulated_cost = agent_metrics.accumulated_cost
|
||||
max_budget_per_task = agent_metrics.max_budget_per_task
|
||||
|
||||
# Extract accumulated_token_usage from Metrics object
|
||||
accumulated_token_usage = agent_metrics.accumulated_token_usage
|
||||
if accumulated_token_usage:
|
||||
prompt_tokens = accumulated_token_usage.prompt_tokens
|
||||
completion_tokens = accumulated_token_usage.completion_tokens
|
||||
cache_read_tokens = accumulated_token_usage.cache_read_tokens
|
||||
cache_write_tokens = accumulated_token_usage.cache_write_tokens
|
||||
reasoning_tokens = accumulated_token_usage.reasoning_tokens
|
||||
context_window = accumulated_token_usage.context_window
|
||||
per_turn_token = accumulated_token_usage.per_turn_token
|
||||
else:
|
||||
prompt_tokens = None
|
||||
completion_tokens = None
|
||||
cache_read_tokens = None
|
||||
cache_write_tokens = None
|
||||
reasoning_tokens = None
|
||||
context_window = None
|
||||
per_turn_token = None
|
||||
|
||||
# Update fields only if values are provided (not None)
|
||||
if accumulated_cost is not None:
|
||||
stored.accumulated_cost = accumulated_cost
|
||||
if max_budget_per_task is not None:
|
||||
stored.max_budget_per_task = max_budget_per_task
|
||||
if prompt_tokens is not None:
|
||||
stored.prompt_tokens = prompt_tokens
|
||||
if completion_tokens is not None:
|
||||
stored.completion_tokens = completion_tokens
|
||||
if cache_read_tokens is not None:
|
||||
stored.cache_read_tokens = cache_read_tokens
|
||||
if cache_write_tokens is not None:
|
||||
stored.cache_write_tokens = cache_write_tokens
|
||||
if reasoning_tokens is not None:
|
||||
stored.reasoning_tokens = reasoning_tokens
|
||||
if context_window is not None:
|
||||
stored.context_window = context_window
|
||||
if per_turn_token is not None:
|
||||
stored.per_turn_token = per_turn_token
|
||||
|
||||
# Update last_updated_at timestamp
|
||||
stored.last_updated_at = utc_now()
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
async def process_stats_event(
|
||||
self,
|
||||
event: ConversationStateUpdateEvent,
|
||||
conversation_id: UUID,
|
||||
) -> None:
|
||||
"""Process a stats event and update conversation statistics.
|
||||
|
||||
Args:
|
||||
event: The ConversationStateUpdateEvent with key='stats'
|
||||
conversation_id: The ID of the conversation to update
|
||||
"""
|
||||
try:
|
||||
# Parse event value into ConversationStats model for type safety
|
||||
# event.value can be a dict (from JSON deserialization) or a ConversationStats object
|
||||
event_value = event.value
|
||||
conversation_stats: ConversationStats | None = None
|
||||
|
||||
if isinstance(event_value, ConversationStats):
|
||||
# Already a ConversationStats object
|
||||
conversation_stats = event_value
|
||||
elif isinstance(event_value, dict):
|
||||
# Parse dict into ConversationStats model
|
||||
# This validates the structure and ensures type safety
|
||||
conversation_stats = ConversationStats.model_validate(event_value)
|
||||
elif hasattr(event_value, 'usage_to_metrics'):
|
||||
# Handle objects with usage_to_metrics attribute (e.g., from tests)
|
||||
# Convert to dict first, then validate
|
||||
stats_dict = {'usage_to_metrics': event_value.usage_to_metrics}
|
||||
conversation_stats = ConversationStats.model_validate(stats_dict)
|
||||
|
||||
if conversation_stats and conversation_stats.usage_to_metrics:
|
||||
# Pass ConversationStats object directly for type safety
|
||||
await self.update_conversation_statistics(
|
||||
conversation_id, conversation_stats
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
'Error updating conversation statistics for conversation %s',
|
||||
conversation_id,
|
||||
stack_info=True,
|
||||
)
|
||||
|
||||
async def _secure_select(self):
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
|
||||
@@ -6,7 +6,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
@@ -15,6 +14,7 @@ from sqlalchemy import UUID as SQLUUID
|
||||
from sqlalchemy import Column, Enum, String, and_, func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from openhands.agent_server.utils import utc_now
|
||||
from openhands.app_server.event_callback.event_callback_models import (
|
||||
CreateEventCallbackRequest,
|
||||
EventCallback,
|
||||
@@ -177,7 +177,7 @@ class SQLEventCallbackService(EventCallbackService):
|
||||
return EventCallbackPage(items=callbacks, next_page_id=next_page_id)
|
||||
|
||||
async def save_event_callback(self, event_callback: EventCallback) -> EventCallback:
|
||||
event_callback.updated_at = datetime.now()
|
||||
event_callback.updated_at = utc_now()
|
||||
stored_callback = StoredEventCallback(**event_callback.model_dump())
|
||||
await self.db_session.merge(stored_callback)
|
||||
return event_callback
|
||||
|
||||
@@ -43,6 +43,7 @@ from openhands.app_server.user.specifiy_user_context import (
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk import Event
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
from openhands.server.user_auth.default_user_auth import DefaultUserAuth
|
||||
from openhands.server.user_auth.user_auth import (
|
||||
get_for_user as get_user_auth_for_user,
|
||||
@@ -144,6 +145,13 @@ async def on_event(
|
||||
*[event_service.save_event(conversation_id, event) for event in events]
|
||||
)
|
||||
|
||||
# Process stats events for V1 conversations
|
||||
for event in events:
|
||||
if isinstance(event, ConversationStateUpdateEvent) and event.key == 'stats':
|
||||
await app_conversation_info_service.process_stats_event(
|
||||
event, conversation_id
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
_run_callbacks_in_bg_and_close(
|
||||
conversation_id, app_conversation_info.created_by_user_id, events
|
||||
|
||||
@@ -324,8 +324,8 @@ class ActionExecutionClient(Runtime):
|
||||
'POST',
|
||||
f'{self.action_execution_server_url}/execute_action',
|
||||
json=execution_action_body,
|
||||
# wait additional seconds to get the timeout error from server side
|
||||
timeout=action.timeout + 100,
|
||||
# wait a few more seconds to get the timeout error from client side
|
||||
timeout=action.timeout + 5,
|
||||
)
|
||||
assert response.is_closed
|
||||
output = response.json()
|
||||
|
||||
628
tests/unit/app_server/test_app_conversation_service_base.py
Normal file
628
tests/unit/app_server/test_app_conversation_service_base.py
Normal file
@@ -0,0 +1,628 @@
|
||||
"""Unit tests for git functionality in AppConversationServiceBase.
|
||||
|
||||
This module tests the git-related functionality, specifically the clone_or_init_git_repo method
|
||||
and the recent bug fixes for git checkout operations.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import AgentType
|
||||
from openhands.app_server.app_conversation.app_conversation_service_base import (
|
||||
AppConversationServiceBase,
|
||||
)
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
|
||||
|
||||
class MockUserInfo:
|
||||
"""Mock class for UserInfo to simulate user settings."""
|
||||
|
||||
def __init__(
|
||||
self, git_user_name: str | None = None, git_user_email: str | None = None
|
||||
):
|
||||
self.git_user_name = git_user_name
|
||||
self.git_user_email = git_user_email
|
||||
|
||||
|
||||
class MockCommandResult:
|
||||
"""Mock class for command execution result."""
|
||||
|
||||
def __init__(self, exit_code: int = 0, stderr: str = ''):
|
||||
self.exit_code = exit_code
|
||||
self.stderr = stderr
|
||||
|
||||
|
||||
class MockWorkspace:
|
||||
"""Mock class for AsyncRemoteWorkspace."""
|
||||
|
||||
def __init__(self, working_dir: str = '/workspace'):
|
||||
self.working_dir = working_dir
|
||||
self.execute_command = AsyncMock(return_value=MockCommandResult())
|
||||
|
||||
|
||||
class MockAppConversationServiceBase:
|
||||
"""Mock class to test git functionality without complex dependencies."""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = MagicMock()
|
||||
|
||||
async def clone_or_init_git_repo(
|
||||
self,
|
||||
workspace_path: str,
|
||||
repo_url: str,
|
||||
branch: str = 'main',
|
||||
timeout: int = 300,
|
||||
) -> bool:
|
||||
"""Clone or initialize a git repository.
|
||||
|
||||
This is a simplified version of the actual method for testing purposes.
|
||||
"""
|
||||
try:
|
||||
# Try to clone the repository
|
||||
clone_result = subprocess.run(
|
||||
['git', 'clone', '--branch', branch, repo_url, workspace_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if clone_result.returncode == 0:
|
||||
self.logger.info(
|
||||
f'Successfully cloned repository {repo_url} to {workspace_path}'
|
||||
)
|
||||
return True
|
||||
|
||||
# If clone fails, try to checkout the branch
|
||||
checkout_result = subprocess.run(
|
||||
['git', 'checkout', branch],
|
||||
cwd=workspace_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if checkout_result.returncode == 0:
|
||||
self.logger.info(f'Successfully checked out branch {branch}')
|
||||
return True
|
||||
else:
|
||||
self.logger.error(
|
||||
f'Failed to checkout branch {branch}: {checkout_result.stderr}'
|
||||
)
|
||||
return False
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
self.logger.error(f'Git operation timed out after {timeout} seconds')
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.error(f'Git operation failed: {str(e)}')
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service():
|
||||
"""Create a mock service instance for testing."""
|
||||
return MockAppConversationServiceBase()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_git_repo_successful_clone(service):
|
||||
"""Test successful git clone operation."""
|
||||
with patch('subprocess.run') as mock_run:
|
||||
# Mock successful clone
|
||||
mock_run.return_value = MagicMock(returncode=0, stderr='', stdout='Cloning...')
|
||||
|
||||
result = await service.clone_or_init_git_repo(
|
||||
workspace_path='/tmp/test_repo',
|
||||
repo_url='https://github.com/test/repo.git',
|
||||
branch='main',
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_run.assert_called_once_with(
|
||||
[
|
||||
'git',
|
||||
'clone',
|
||||
'--branch',
|
||||
'main',
|
||||
'https://github.com/test/repo.git',
|
||||
'/tmp/test_repo',
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
service.logger.info.assert_called_with(
|
||||
'Successfully cloned repository https://github.com/test/repo.git to /tmp/test_repo'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_git_repo_clone_fails_checkout_succeeds(service):
|
||||
"""Test git clone fails but checkout succeeds."""
|
||||
with patch('subprocess.run') as mock_run:
|
||||
# Mock clone failure, then checkout success
|
||||
mock_run.side_effect = [
|
||||
MagicMock(returncode=1, stderr='Clone failed', stdout=''), # Clone fails
|
||||
MagicMock(
|
||||
returncode=0, stderr='', stdout='Switched to branch'
|
||||
), # Checkout succeeds
|
||||
]
|
||||
|
||||
result = await service.clone_or_init_git_repo(
|
||||
workspace_path='/tmp/test_repo',
|
||||
repo_url='https://github.com/test/repo.git',
|
||||
branch='feature-branch',
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert mock_run.call_count == 2
|
||||
|
||||
# Check clone call
|
||||
mock_run.assert_any_call(
|
||||
[
|
||||
'git',
|
||||
'clone',
|
||||
'--branch',
|
||||
'feature-branch',
|
||||
'https://github.com/test/repo.git',
|
||||
'/tmp/test_repo',
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
# Check checkout call
|
||||
mock_run.assert_any_call(
|
||||
['git', 'checkout', 'feature-branch'],
|
||||
cwd='/tmp/test_repo',
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
service.logger.info.assert_called_with(
|
||||
'Successfully checked out branch feature-branch'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_git_repo_both_operations_fail(service):
|
||||
"""Test both git clone and checkout operations fail."""
|
||||
with patch('subprocess.run') as mock_run:
|
||||
# Mock both operations failing
|
||||
mock_run.side_effect = [
|
||||
MagicMock(returncode=1, stderr='Clone failed', stdout=''), # Clone fails
|
||||
MagicMock(
|
||||
returncode=1, stderr='Checkout failed', stdout=''
|
||||
), # Checkout fails
|
||||
]
|
||||
|
||||
result = await service.clone_or_init_git_repo(
|
||||
workspace_path='/tmp/test_repo',
|
||||
repo_url='https://github.com/test/repo.git',
|
||||
branch='nonexistent-branch',
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
assert mock_run.call_count == 2
|
||||
service.logger.error.assert_called_with(
|
||||
'Failed to checkout branch nonexistent-branch: Checkout failed'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_git_repo_timeout(service):
|
||||
"""Test git operation timeout."""
|
||||
with patch('subprocess.run') as mock_run:
|
||||
# Mock timeout exception
|
||||
mock_run.side_effect = subprocess.TimeoutExpired(
|
||||
cmd=['git', 'clone'], timeout=300
|
||||
)
|
||||
|
||||
result = await service.clone_or_init_git_repo(
|
||||
workspace_path='/tmp/test_repo',
|
||||
repo_url='https://github.com/test/repo.git',
|
||||
branch='main',
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
service.logger.error.assert_called_with(
|
||||
'Git operation timed out after 300 seconds'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_git_repo_exception(service):
|
||||
"""Test git operation with unexpected exception."""
|
||||
with patch('subprocess.run') as mock_run:
|
||||
# Mock unexpected exception
|
||||
mock_run.side_effect = Exception('Unexpected error')
|
||||
|
||||
result = await service.clone_or_init_git_repo(
|
||||
workspace_path='/tmp/test_repo',
|
||||
repo_url='https://github.com/test/repo.git',
|
||||
branch='main',
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
service.logger.error.assert_called_with(
|
||||
'Git operation failed: Unexpected error'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_or_init_git_repo_custom_timeout(service):
|
||||
"""Test git operation with custom timeout."""
|
||||
with patch('subprocess.run') as mock_run:
|
||||
# Mock successful clone with custom timeout
|
||||
mock_run.return_value = MagicMock(returncode=0, stderr='', stdout='Cloning...')
|
||||
|
||||
result = await service.clone_or_init_git_repo(
|
||||
workspace_path='/tmp/test_repo',
|
||||
repo_url='https://github.com/test/repo.git',
|
||||
branch='main',
|
||||
timeout=600, # Custom timeout
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_run.assert_called_once_with(
|
||||
[
|
||||
'git',
|
||||
'clone',
|
||||
'--branch',
|
||||
'main',
|
||||
'https://github.com/test/repo.git',
|
||||
'/tmp/test_repo',
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600, # Verify custom timeout is used
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.LLMSummarizingCondenser'
|
||||
)
|
||||
def test_create_condenser_default_agent_with_none_max_size(mock_condenser_class):
|
||||
"""Test _create_condenser for DEFAULT agent with condenser_max_size = None uses default."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
mock_llm = MagicMock()
|
||||
mock_llm_copy = MagicMock()
|
||||
mock_llm_copy.usage_id = 'condenser'
|
||||
mock_llm.model_copy.return_value = mock_llm_copy
|
||||
mock_condenser_instance = MagicMock()
|
||||
mock_condenser_class.return_value = mock_condenser_instance
|
||||
|
||||
# Act
|
||||
service._create_condenser(mock_llm, AgentType.DEFAULT, None)
|
||||
|
||||
# Assert
|
||||
mock_condenser_class.assert_called_once()
|
||||
call_kwargs = mock_condenser_class.call_args[1]
|
||||
# When condenser_max_size is None, max_size should not be passed (uses SDK default of 120)
|
||||
assert 'max_size' not in call_kwargs
|
||||
# keep_first is never passed (uses SDK default of 4)
|
||||
assert 'keep_first' not in call_kwargs
|
||||
assert call_kwargs['llm'].usage_id == 'condenser'
|
||||
mock_llm.model_copy.assert_called_once()
|
||||
|
||||
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.LLMSummarizingCondenser'
|
||||
)
|
||||
def test_create_condenser_default_agent_with_custom_max_size(mock_condenser_class):
|
||||
"""Test _create_condenser for DEFAULT agent with custom condenser_max_size."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
mock_llm = MagicMock()
|
||||
mock_llm_copy = MagicMock()
|
||||
mock_llm_copy.usage_id = 'condenser'
|
||||
mock_llm.model_copy.return_value = mock_llm_copy
|
||||
mock_condenser_instance = MagicMock()
|
||||
mock_condenser_class.return_value = mock_condenser_instance
|
||||
|
||||
# Act
|
||||
service._create_condenser(mock_llm, AgentType.DEFAULT, 150)
|
||||
|
||||
# Assert
|
||||
mock_condenser_class.assert_called_once()
|
||||
call_kwargs = mock_condenser_class.call_args[1]
|
||||
assert call_kwargs['max_size'] == 150 # Custom value should be used
|
||||
# keep_first is never passed (uses SDK default of 4)
|
||||
assert 'keep_first' not in call_kwargs
|
||||
assert call_kwargs['llm'].usage_id == 'condenser'
|
||||
mock_llm.model_copy.assert_called_once()
|
||||
|
||||
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.LLMSummarizingCondenser'
|
||||
)
|
||||
def test_create_condenser_plan_agent_with_none_max_size(mock_condenser_class):
|
||||
"""Test _create_condenser for PLAN agent with condenser_max_size = None uses default."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
mock_llm = MagicMock()
|
||||
mock_llm_copy = MagicMock()
|
||||
mock_llm_copy.usage_id = 'planning_condenser'
|
||||
mock_llm.model_copy.return_value = mock_llm_copy
|
||||
mock_condenser_instance = MagicMock()
|
||||
mock_condenser_class.return_value = mock_condenser_instance
|
||||
|
||||
# Act
|
||||
service._create_condenser(mock_llm, AgentType.PLAN, None)
|
||||
|
||||
# Assert
|
||||
mock_condenser_class.assert_called_once()
|
||||
call_kwargs = mock_condenser_class.call_args[1]
|
||||
# When condenser_max_size is None, max_size should not be passed (uses SDK default of 120)
|
||||
assert 'max_size' not in call_kwargs
|
||||
# keep_first is never passed (uses SDK default of 4)
|
||||
assert 'keep_first' not in call_kwargs
|
||||
assert call_kwargs['llm'].usage_id == 'planning_condenser'
|
||||
mock_llm.model_copy.assert_called_once()
|
||||
|
||||
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.LLMSummarizingCondenser'
|
||||
)
|
||||
def test_create_condenser_plan_agent_with_custom_max_size(mock_condenser_class):
|
||||
"""Test _create_condenser for PLAN agent with custom condenser_max_size."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
mock_llm = MagicMock()
|
||||
mock_llm_copy = MagicMock()
|
||||
mock_llm_copy.usage_id = 'planning_condenser'
|
||||
mock_llm.model_copy.return_value = mock_llm_copy
|
||||
mock_condenser_instance = MagicMock()
|
||||
mock_condenser_class.return_value = mock_condenser_instance
|
||||
|
||||
# Act
|
||||
service._create_condenser(mock_llm, AgentType.PLAN, 200)
|
||||
|
||||
# Assert
|
||||
mock_condenser_class.assert_called_once()
|
||||
call_kwargs = mock_condenser_class.call_args[1]
|
||||
assert call_kwargs['max_size'] == 200 # Custom value should be used
|
||||
# keep_first is never passed (uses SDK default of 4)
|
||||
assert 'keep_first' not in call_kwargs
|
||||
assert call_kwargs['llm'].usage_id == 'planning_condenser'
|
||||
mock_llm.model_copy.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for _configure_git_user_settings
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _create_service_with_mock_user_context(user_info: MockUserInfo) -> tuple:
|
||||
"""Create a mock service with the actual _configure_git_user_settings method.
|
||||
|
||||
Uses MagicMock for the service but binds the real method for testing.
|
||||
|
||||
Returns a tuple of (service, mock_user_context) for testing.
|
||||
"""
|
||||
mock_user_context = MagicMock()
|
||||
mock_user_context.get_user_info = AsyncMock(return_value=user_info)
|
||||
|
||||
# Create a simple mock service and set required attribute
|
||||
service = MagicMock()
|
||||
service.user_context = mock_user_context
|
||||
|
||||
# Bind the actual method from the real class to test real implementation
|
||||
service._configure_git_user_settings = (
|
||||
lambda workspace: AppConversationServiceBase._configure_git_user_settings(
|
||||
service, workspace
|
||||
)
|
||||
)
|
||||
|
||||
return service, mock_user_context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workspace():
|
||||
"""Create a mock workspace instance for testing."""
|
||||
return MockWorkspace(working_dir='/workspace/project')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_git_user_settings_both_name_and_email(mock_workspace):
|
||||
"""Test configuring both git user name and email."""
|
||||
user_info = MockUserInfo(
|
||||
git_user_name='Test User', git_user_email='test@example.com'
|
||||
)
|
||||
service, mock_user_context = _create_service_with_mock_user_context(user_info)
|
||||
|
||||
await service._configure_git_user_settings(mock_workspace)
|
||||
|
||||
# Verify get_user_info was called
|
||||
mock_user_context.get_user_info.assert_called_once()
|
||||
|
||||
# Verify both git config commands were executed
|
||||
assert mock_workspace.execute_command.call_count == 2
|
||||
|
||||
# Check git config user.name call
|
||||
mock_workspace.execute_command.assert_any_call(
|
||||
'git config --global user.name "Test User"', '/workspace/project'
|
||||
)
|
||||
|
||||
# Check git config user.email call
|
||||
mock_workspace.execute_command.assert_any_call(
|
||||
'git config --global user.email "test@example.com"', '/workspace/project'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_git_user_settings_only_name(mock_workspace):
|
||||
"""Test configuring only git user name."""
|
||||
user_info = MockUserInfo(git_user_name='Test User', git_user_email=None)
|
||||
service, _ = _create_service_with_mock_user_context(user_info)
|
||||
|
||||
await service._configure_git_user_settings(mock_workspace)
|
||||
|
||||
# Verify only user.name was configured
|
||||
assert mock_workspace.execute_command.call_count == 1
|
||||
mock_workspace.execute_command.assert_called_once_with(
|
||||
'git config --global user.name "Test User"', '/workspace/project'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_git_user_settings_only_email(mock_workspace):
|
||||
"""Test configuring only git user email."""
|
||||
user_info = MockUserInfo(git_user_name=None, git_user_email='test@example.com')
|
||||
service, _ = _create_service_with_mock_user_context(user_info)
|
||||
|
||||
await service._configure_git_user_settings(mock_workspace)
|
||||
|
||||
# Verify only user.email was configured
|
||||
assert mock_workspace.execute_command.call_count == 1
|
||||
mock_workspace.execute_command.assert_called_once_with(
|
||||
'git config --global user.email "test@example.com"', '/workspace/project'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_git_user_settings_neither_set(mock_workspace):
|
||||
"""Test when neither git user name nor email is set."""
|
||||
user_info = MockUserInfo(git_user_name=None, git_user_email=None)
|
||||
service, _ = _create_service_with_mock_user_context(user_info)
|
||||
|
||||
await service._configure_git_user_settings(mock_workspace)
|
||||
|
||||
# Verify no git config commands were executed
|
||||
mock_workspace.execute_command.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_git_user_settings_empty_strings(mock_workspace):
|
||||
"""Test when git user name and email are empty strings."""
|
||||
user_info = MockUserInfo(git_user_name='', git_user_email='')
|
||||
service, _ = _create_service_with_mock_user_context(user_info)
|
||||
|
||||
await service._configure_git_user_settings(mock_workspace)
|
||||
|
||||
# Empty strings are falsy, so no commands should be executed
|
||||
mock_workspace.execute_command.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_git_user_settings_get_user_info_fails(mock_workspace):
|
||||
"""Test handling of exception when get_user_info fails."""
|
||||
user_info = MockUserInfo()
|
||||
service, mock_user_context = _create_service_with_mock_user_context(user_info)
|
||||
mock_user_context.get_user_info = AsyncMock(
|
||||
side_effect=Exception('User info error')
|
||||
)
|
||||
|
||||
# Should not raise exception, just log warning
|
||||
await service._configure_git_user_settings(mock_workspace)
|
||||
|
||||
# Verify no git config commands were executed
|
||||
mock_workspace.execute_command.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_git_user_settings_name_command_fails(mock_workspace):
|
||||
"""Test handling when git config user.name command fails."""
|
||||
user_info = MockUserInfo(
|
||||
git_user_name='Test User', git_user_email='test@example.com'
|
||||
)
|
||||
service, _ = _create_service_with_mock_user_context(user_info)
|
||||
|
||||
# Make the first command fail (user.name), second succeed (user.email)
|
||||
mock_workspace.execute_command = AsyncMock(
|
||||
side_effect=[
|
||||
MockCommandResult(exit_code=1, stderr='Permission denied'),
|
||||
MockCommandResult(exit_code=0),
|
||||
]
|
||||
)
|
||||
|
||||
# Should not raise exception
|
||||
await service._configure_git_user_settings(mock_workspace)
|
||||
|
||||
# Verify both commands were still attempted
|
||||
assert mock_workspace.execute_command.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_git_user_settings_email_command_fails(mock_workspace):
|
||||
"""Test handling when git config user.email command fails."""
|
||||
user_info = MockUserInfo(
|
||||
git_user_name='Test User', git_user_email='test@example.com'
|
||||
)
|
||||
service, _ = _create_service_with_mock_user_context(user_info)
|
||||
|
||||
# Make the first command succeed (user.name), second fail (user.email)
|
||||
mock_workspace.execute_command = AsyncMock(
|
||||
side_effect=[
|
||||
MockCommandResult(exit_code=0),
|
||||
MockCommandResult(exit_code=1, stderr='Permission denied'),
|
||||
]
|
||||
)
|
||||
|
||||
# Should not raise exception
|
||||
await service._configure_git_user_settings(mock_workspace)
|
||||
|
||||
# Verify both commands were still attempted
|
||||
assert mock_workspace.execute_command.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configure_git_user_settings_special_characters_in_name(mock_workspace):
|
||||
"""Test git user name with special characters."""
|
||||
user_info = MockUserInfo(
|
||||
git_user_name="Test O'Brien", git_user_email='test@example.com'
|
||||
)
|
||||
service, _ = _create_service_with_mock_user_context(user_info)
|
||||
|
||||
await service._configure_git_user_settings(mock_workspace)
|
||||
|
||||
# Verify the name is passed with special characters
|
||||
mock_workspace.execute_command.assert_any_call(
|
||||
'git config --global user.name "Test O\'Brien"', '/workspace/project'
|
||||
)
|
||||
@@ -63,6 +63,7 @@ class TestLiveStatusAppConversationService:
|
||||
self.mock_user.llm_api_key = 'test_api_key'
|
||||
self.mock_user.confirmation_mode = False
|
||||
self.mock_user.search_api_key = None # Default to None
|
||||
self.mock_user.condenser_max_size = None # Default to None
|
||||
|
||||
# Mock sandbox
|
||||
self.mock_sandbox = Mock(spec=SandboxInfo)
|
||||
@@ -421,20 +422,21 @@ class TestLiveStatusAppConversationService:
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.get_planning_tools'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.get_planning_condenser'
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.AppConversationServiceBase._create_condenser'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.format_plan_structure'
|
||||
)
|
||||
def test_create_agent_with_context_planning_agent(
|
||||
self, mock_format_plan, mock_get_condenser, mock_get_tools
|
||||
self, mock_format_plan, mock_create_condenser, mock_get_tools
|
||||
):
|
||||
"""Test _create_agent_with_context for planning agent type."""
|
||||
# Arrange
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model_copy.return_value = mock_llm
|
||||
mock_get_tools.return_value = []
|
||||
mock_get_condenser.return_value = Mock()
|
||||
mock_condenser = Mock()
|
||||
mock_create_condenser.return_value = mock_condenser
|
||||
mock_format_plan.return_value = 'test_plan_structure'
|
||||
mcp_config = {'default': {'url': 'test'}}
|
||||
system_message_suffix = 'Test suffix'
|
||||
@@ -448,7 +450,11 @@ class TestLiveStatusAppConversationService:
|
||||
mock_agent_class.return_value = mock_agent_instance
|
||||
|
||||
self.service._create_agent_with_context(
|
||||
mock_llm, AgentType.PLAN, system_message_suffix, mcp_config
|
||||
mock_llm,
|
||||
AgentType.PLAN,
|
||||
system_message_suffix,
|
||||
mcp_config,
|
||||
self.mock_user.condenser_max_size,
|
||||
)
|
||||
|
||||
# Assert
|
||||
@@ -462,22 +468,27 @@ class TestLiveStatusAppConversationService:
|
||||
)
|
||||
assert call_kwargs['mcp_config'] == mcp_config
|
||||
assert call_kwargs['security_analyzer'] is None
|
||||
assert call_kwargs['condenser'] == mock_condenser
|
||||
mock_create_condenser.assert_called_once_with(
|
||||
mock_llm, AgentType.PLAN, self.mock_user.condenser_max_size
|
||||
)
|
||||
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.get_default_tools'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.get_default_condenser'
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.AppConversationServiceBase._create_condenser'
|
||||
)
|
||||
def test_create_agent_with_context_default_agent(
|
||||
self, mock_get_condenser, mock_get_tools
|
||||
self, mock_create_condenser, mock_get_tools
|
||||
):
|
||||
"""Test _create_agent_with_context for default agent type."""
|
||||
# Arrange
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model_copy.return_value = mock_llm
|
||||
mock_get_tools.return_value = []
|
||||
mock_get_condenser.return_value = Mock()
|
||||
mock_condenser = Mock()
|
||||
mock_create_condenser.return_value = mock_condenser
|
||||
mcp_config = {'default': {'url': 'test'}}
|
||||
|
||||
# Act
|
||||
@@ -489,7 +500,11 @@ class TestLiveStatusAppConversationService:
|
||||
mock_agent_class.return_value = mock_agent_instance
|
||||
|
||||
self.service._create_agent_with_context(
|
||||
mock_llm, AgentType.DEFAULT, None, mcp_config
|
||||
mock_llm,
|
||||
AgentType.DEFAULT,
|
||||
None,
|
||||
mcp_config,
|
||||
self.mock_user.condenser_max_size,
|
||||
)
|
||||
|
||||
# Assert
|
||||
@@ -498,7 +513,11 @@ class TestLiveStatusAppConversationService:
|
||||
assert call_kwargs['llm'] == mock_llm
|
||||
assert call_kwargs['system_prompt_kwargs']['cli_mode'] is False
|
||||
assert call_kwargs['mcp_config'] == mcp_config
|
||||
assert call_kwargs['condenser'] == mock_condenser
|
||||
mock_get_tools.assert_called_once_with(enable_browser=True)
|
||||
mock_create_condenser.assert_called_once_with(
|
||||
mock_llm, AgentType.DEFAULT, self.mock_user.condenser_max_size
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
@@ -693,6 +712,10 @@ class TestLiveStatusAppConversationService:
|
||||
self.mock_user, 'gpt-4'
|
||||
)
|
||||
self.service._create_agent_with_context.assert_called_once_with(
|
||||
mock_llm, AgentType.DEFAULT, 'Test suffix', mock_mcp_config
|
||||
mock_llm,
|
||||
AgentType.DEFAULT,
|
||||
'Test suffix',
|
||||
mock_mcp_config,
|
||||
self.mock_user.condenser_max_size,
|
||||
)
|
||||
self.service._finalize_conversation_request.assert_called_once()
|
||||
|
||||
615
tests/unit/app_server/test_webhook_router_stats.py
Normal file
615
tests/unit/app_server/test_webhook_router_stats.py
Normal file
@@ -0,0 +1,615 @@
|
||||
"""Tests for stats event processing in webhook_router.
|
||||
|
||||
This module tests the stats event processing functionality introduced for
|
||||
updating conversation statistics from ConversationStateUpdateEvent events.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.sdk.conversation.conversation_stats import ConversationStats
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
from openhands.sdk.llm.utils.metrics import Metrics, TokenUsage
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
return SQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def v1_conversation_metadata(async_session, service):
|
||||
"""Create a V1 conversation metadata record for testing."""
|
||||
conversation_id = uuid4()
|
||||
stored = StoredConversationMetadata(
|
||||
conversation_id=str(conversation_id),
|
||||
user_id='test_user_123',
|
||||
sandbox_id='sandbox_123',
|
||||
conversation_version='V1',
|
||||
title='Test Conversation',
|
||||
accumulated_cost=0.0,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
reasoning_tokens=0,
|
||||
context_window=0,
|
||||
per_turn_token=0,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
async_session.add(stored)
|
||||
await async_session.commit()
|
||||
return conversation_id, stored
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stats_event_with_dict_value():
|
||||
"""Create a ConversationStateUpdateEvent with dict value."""
|
||||
event_value = {
|
||||
'usage_to_metrics': {
|
||||
'agent': {
|
||||
'accumulated_cost': 0.03411525,
|
||||
'max_budget_per_task': None,
|
||||
'accumulated_token_usage': {
|
||||
'prompt_tokens': 8770,
|
||||
'completion_tokens': 82,
|
||||
'cache_read_tokens': 0,
|
||||
'cache_write_tokens': 8767,
|
||||
'reasoning_tokens': 0,
|
||||
'context_window': 0,
|
||||
'per_turn_token': 8852,
|
||||
},
|
||||
},
|
||||
'condenser': {
|
||||
'accumulated_cost': 0.0,
|
||||
'accumulated_token_usage': {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return ConversationStateUpdateEvent(key='stats', value=event_value)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stats_event_with_object_value():
|
||||
"""Create a ConversationStateUpdateEvent with object value."""
|
||||
event_value = MagicMock()
|
||||
event_value.usage_to_metrics = {
|
||||
'agent': {
|
||||
'accumulated_cost': 0.05,
|
||||
'accumulated_token_usage': {
|
||||
'prompt_tokens': 1000,
|
||||
'completion_tokens': 100,
|
||||
},
|
||||
}
|
||||
}
|
||||
return ConversationStateUpdateEvent(key='stats', value=event_value)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stats_event_no_usage_to_metrics():
|
||||
"""Create a ConversationStateUpdateEvent without usage_to_metrics."""
|
||||
event_value = {'some_other_key': 'value'}
|
||||
return ConversationStateUpdateEvent(key='stats', value=event_value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for update_conversation_statistics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateConversationStatistics:
|
||||
"""Test the update_conversation_statistics method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_statistics_success(
|
||||
self, service, async_session, v1_conversation_metadata
|
||||
):
|
||||
"""Test successfully updating conversation statistics."""
|
||||
conversation_id, stored = v1_conversation_metadata
|
||||
|
||||
agent_metrics = Metrics(
|
||||
model_name='test-model',
|
||||
accumulated_cost=0.03411525,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
model='test-model',
|
||||
prompt_tokens=8770,
|
||||
completion_tokens=82,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=8767,
|
||||
reasoning_tokens=0,
|
||||
context_window=0,
|
||||
per_turn_token=8852,
|
||||
),
|
||||
)
|
||||
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
|
||||
|
||||
await service.update_conversation_statistics(conversation_id, stats)
|
||||
|
||||
# Verify the update
|
||||
await async_session.refresh(stored)
|
||||
assert stored.accumulated_cost == 0.03411525
|
||||
assert stored.max_budget_per_task == 10.0
|
||||
assert stored.prompt_tokens == 8770
|
||||
assert stored.completion_tokens == 82
|
||||
assert stored.cache_read_tokens == 0
|
||||
assert stored.cache_write_tokens == 8767
|
||||
assert stored.reasoning_tokens == 0
|
||||
assert stored.context_window == 0
|
||||
assert stored.per_turn_token == 8852
|
||||
assert stored.last_updated_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_statistics_partial_update(
|
||||
self, service, async_session, v1_conversation_metadata
|
||||
):
|
||||
"""Test updating only some statistics fields."""
|
||||
conversation_id, stored = v1_conversation_metadata
|
||||
|
||||
# Set initial values
|
||||
stored.accumulated_cost = 0.01
|
||||
stored.prompt_tokens = 100
|
||||
await async_session.commit()
|
||||
|
||||
agent_metrics = Metrics(
|
||||
model_name='test-model',
|
||||
accumulated_cost=0.05,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
model='test-model',
|
||||
prompt_tokens=200,
|
||||
completion_tokens=0, # Default value
|
||||
),
|
||||
)
|
||||
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
|
||||
|
||||
await service.update_conversation_statistics(conversation_id, stats)
|
||||
|
||||
# Verify updated fields
|
||||
await async_session.refresh(stored)
|
||||
assert stored.accumulated_cost == 0.05
|
||||
assert stored.prompt_tokens == 200
|
||||
# completion_tokens should remain unchanged (not None in stats)
|
||||
assert stored.completion_tokens == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_statistics_no_agent_metrics(
|
||||
self, service, v1_conversation_metadata
|
||||
):
|
||||
"""Test that update is skipped when no agent metrics are present."""
|
||||
conversation_id, stored = v1_conversation_metadata
|
||||
original_cost = stored.accumulated_cost
|
||||
|
||||
condenser_metrics = Metrics(
|
||||
model_name='test-model',
|
||||
accumulated_cost=0.1,
|
||||
)
|
||||
stats = ConversationStats(usage_to_metrics={'condenser': condenser_metrics})
|
||||
|
||||
await service.update_conversation_statistics(conversation_id, stats)
|
||||
|
||||
# Verify no update occurred
|
||||
assert stored.accumulated_cost == original_cost
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_statistics_conversation_not_found(self, service):
|
||||
"""Test that update is skipped when conversation doesn't exist."""
|
||||
nonexistent_id = uuid4()
|
||||
agent_metrics = Metrics(
|
||||
model_name='test-model',
|
||||
accumulated_cost=0.1,
|
||||
)
|
||||
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
|
||||
|
||||
# Should not raise an exception
|
||||
await service.update_conversation_statistics(nonexistent_id, stats)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_statistics_v0_conversation_skipped(
|
||||
self, service, async_session
|
||||
):
|
||||
"""Test that V0 conversations are skipped."""
|
||||
conversation_id = uuid4()
|
||||
stored = StoredConversationMetadata(
|
||||
conversation_id=str(conversation_id),
|
||||
user_id='test_user_123',
|
||||
sandbox_id='sandbox_123',
|
||||
conversation_version='V0', # V0 conversation
|
||||
title='V0 Conversation',
|
||||
accumulated_cost=0.0,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
async_session.add(stored)
|
||||
await async_session.commit()
|
||||
|
||||
original_cost = stored.accumulated_cost
|
||||
|
||||
agent_metrics = Metrics(
|
||||
model_name='test-model',
|
||||
accumulated_cost=0.1,
|
||||
)
|
||||
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
|
||||
|
||||
await service.update_conversation_statistics(conversation_id, stats)
|
||||
|
||||
# Verify no update occurred
|
||||
await async_session.refresh(stored)
|
||||
assert stored.accumulated_cost == original_cost
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_statistics_with_none_values(
|
||||
self, service, async_session, v1_conversation_metadata
|
||||
):
|
||||
"""Test that None values in stats don't overwrite existing values."""
|
||||
conversation_id, stored = v1_conversation_metadata
|
||||
|
||||
# Set initial values
|
||||
stored.accumulated_cost = 0.01
|
||||
stored.max_budget_per_task = 5.0
|
||||
stored.prompt_tokens = 100
|
||||
await async_session.commit()
|
||||
|
||||
agent_metrics = Metrics(
|
||||
model_name='test-model',
|
||||
accumulated_cost=0.05,
|
||||
max_budget_per_task=None, # None value
|
||||
accumulated_token_usage=TokenUsage(
|
||||
model='test-model',
|
||||
prompt_tokens=200,
|
||||
completion_tokens=0, # Default value (None is not valid for int)
|
||||
),
|
||||
)
|
||||
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
|
||||
|
||||
await service.update_conversation_statistics(conversation_id, stats)
|
||||
|
||||
# Verify updated fields and that None values didn't overwrite
|
||||
await async_session.refresh(stored)
|
||||
assert stored.accumulated_cost == 0.05
|
||||
assert stored.max_budget_per_task == 5.0 # Should remain unchanged
|
||||
assert stored.prompt_tokens == 200
|
||||
assert (
|
||||
stored.completion_tokens == 0
|
||||
) # Should remain unchanged (was 0, None doesn't update)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for process_stats_event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessStatsEvent:
|
||||
"""Test the process_stats_event method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_stats_event_with_dict_value(
|
||||
self,
|
||||
service,
|
||||
async_session,
|
||||
stats_event_with_dict_value,
|
||||
v1_conversation_metadata,
|
||||
):
|
||||
"""Test processing stats event with dict value."""
|
||||
conversation_id, stored = v1_conversation_metadata
|
||||
|
||||
await service.process_stats_event(stats_event_with_dict_value, conversation_id)
|
||||
|
||||
# Verify the update occurred
|
||||
await async_session.refresh(stored)
|
||||
assert stored.accumulated_cost == 0.03411525
|
||||
assert stored.prompt_tokens == 8770
|
||||
assert stored.completion_tokens == 82
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_stats_event_with_object_value(
|
||||
self,
|
||||
service,
|
||||
async_session,
|
||||
stats_event_with_object_value,
|
||||
v1_conversation_metadata,
|
||||
):
|
||||
"""Test processing stats event with object value."""
|
||||
conversation_id, stored = v1_conversation_metadata
|
||||
|
||||
await service.process_stats_event(
|
||||
stats_event_with_object_value, conversation_id
|
||||
)
|
||||
|
||||
# Verify the update occurred
|
||||
await async_session.refresh(stored)
|
||||
assert stored.accumulated_cost == 0.05
|
||||
assert stored.prompt_tokens == 1000
|
||||
assert stored.completion_tokens == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_stats_event_no_usage_to_metrics(
|
||||
self,
|
||||
service,
|
||||
async_session,
|
||||
stats_event_no_usage_to_metrics,
|
||||
v1_conversation_metadata,
|
||||
):
|
||||
"""Test processing stats event without usage_to_metrics."""
|
||||
conversation_id, stored = v1_conversation_metadata
|
||||
original_cost = stored.accumulated_cost
|
||||
|
||||
await service.process_stats_event(
|
||||
stats_event_no_usage_to_metrics, conversation_id
|
||||
)
|
||||
|
||||
# Verify update_conversation_statistics was NOT called
|
||||
await async_session.refresh(stored)
|
||||
assert stored.accumulated_cost == original_cost
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_stats_event_service_error_handled(
|
||||
self, service, stats_event_with_dict_value
|
||||
):
|
||||
"""Test that errors from service are caught and logged."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Should not raise an exception
|
||||
with (
|
||||
patch.object(
|
||||
service,
|
||||
'update_conversation_statistics',
|
||||
side_effect=Exception('Database error'),
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.app_conversation.sql_app_conversation_info_service.logger'
|
||||
) as mock_logger,
|
||||
):
|
||||
await service.process_stats_event(
|
||||
stats_event_with_dict_value, conversation_id
|
||||
)
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_stats_event_empty_usage_to_metrics(
|
||||
self, service, async_session, v1_conversation_metadata
|
||||
):
|
||||
"""Test processing stats event with empty usage_to_metrics."""
|
||||
conversation_id, stored = v1_conversation_metadata
|
||||
original_cost = stored.accumulated_cost
|
||||
|
||||
# Create event with empty usage_to_metrics
|
||||
event = ConversationStateUpdateEvent(
|
||||
key='stats', value={'usage_to_metrics': {}}
|
||||
)
|
||||
|
||||
await service.process_stats_event(event, conversation_id)
|
||||
|
||||
# Empty dict is falsy, so update_conversation_statistics should NOT be called
|
||||
await async_session.refresh(stored)
|
||||
assert stored.accumulated_cost == original_cost
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests for on_event endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOnEventStatsProcessing:
|
||||
"""Test stats event processing in the on_event endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_processes_stats_events(self):
|
||||
"""Test that on_event processes stats events."""
|
||||
from openhands.app_server.event_callback.webhook_router import on_event
|
||||
from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
|
||||
conversation_id = uuid4()
|
||||
sandbox_id = 'sandbox_123'
|
||||
|
||||
# Create stats event
|
||||
stats_event = ConversationStateUpdateEvent(
|
||||
key='stats',
|
||||
value={
|
||||
'usage_to_metrics': {
|
||||
'agent': {
|
||||
'accumulated_cost': 0.1,
|
||||
'accumulated_token_usage': {
|
||||
'prompt_tokens': 1000,
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Create non-stats event
|
||||
other_event = ConversationStateUpdateEvent(
|
||||
key='execution_status', value='running'
|
||||
)
|
||||
|
||||
events = [stats_event, other_event]
|
||||
|
||||
# Mock dependencies
|
||||
mock_sandbox = SandboxInfo(
|
||||
id=sandbox_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test_key',
|
||||
created_by_user_id='user_123',
|
||||
sandbox_spec_id='spec_123',
|
||||
)
|
||||
|
||||
mock_app_conversation_info = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
sandbox_id=sandbox_id,
|
||||
created_by_user_id='user_123',
|
||||
)
|
||||
|
||||
mock_event_service = AsyncMock()
|
||||
mock_app_conversation_info_service = AsyncMock()
|
||||
mock_app_conversation_info_service.get_app_conversation_info.return_value = (
|
||||
mock_app_conversation_info
|
||||
)
|
||||
|
||||
# Set up process_stats_event to call update_conversation_statistics
|
||||
async def process_stats_event_side_effect(event, conversation_id):
|
||||
# Simulate what process_stats_event does - call update_conversation_statistics
|
||||
from openhands.sdk.conversation.conversation_stats import ConversationStats
|
||||
|
||||
if isinstance(event.value, dict):
|
||||
stats = ConversationStats.model_validate(event.value)
|
||||
if stats and stats.usage_to_metrics:
|
||||
await mock_app_conversation_info_service.update_conversation_statistics(
|
||||
conversation_id, stats
|
||||
)
|
||||
|
||||
mock_app_conversation_info_service.process_stats_event.side_effect = (
|
||||
process_stats_event_side_effect
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_sandbox',
|
||||
return_value=mock_sandbox,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=mock_app_conversation_info,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router._run_callbacks_in_bg_and_close'
|
||||
) as mock_callbacks,
|
||||
):
|
||||
await on_event(
|
||||
events=events,
|
||||
conversation_id=conversation_id,
|
||||
sandbox_info=mock_sandbox,
|
||||
app_conversation_info_service=mock_app_conversation_info_service,
|
||||
event_service=mock_event_service,
|
||||
)
|
||||
|
||||
# Verify events were saved
|
||||
assert mock_event_service.save_event.call_count == 2
|
||||
|
||||
# Verify stats event was processed
|
||||
mock_app_conversation_info_service.update_conversation_statistics.assert_called_once()
|
||||
|
||||
# Verify callbacks were scheduled
|
||||
mock_callbacks.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_skips_non_stats_events(self):
|
||||
"""Test that on_event skips non-stats events."""
|
||||
from openhands.app_server.event_callback.webhook_router import on_event
|
||||
from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
|
||||
conversation_id = uuid4()
|
||||
sandbox_id = 'sandbox_123'
|
||||
|
||||
# Create non-stats events
|
||||
events = [
|
||||
ConversationStateUpdateEvent(key='execution_status', value='running'),
|
||||
MessageAction(content='test'),
|
||||
]
|
||||
|
||||
mock_sandbox = SandboxInfo(
|
||||
id=sandbox_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test_key',
|
||||
created_by_user_id='user_123',
|
||||
sandbox_spec_id='spec_123',
|
||||
)
|
||||
|
||||
mock_app_conversation_info = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
sandbox_id=sandbox_id,
|
||||
created_by_user_id='user_123',
|
||||
)
|
||||
|
||||
mock_event_service = AsyncMock()
|
||||
mock_app_conversation_info_service = AsyncMock()
|
||||
mock_app_conversation_info_service.get_app_conversation_info.return_value = (
|
||||
mock_app_conversation_info
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_sandbox',
|
||||
return_value=mock_sandbox,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router.valid_conversation',
|
||||
return_value=mock_app_conversation_info,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.event_callback.webhook_router._run_callbacks_in_bg_and_close'
|
||||
),
|
||||
):
|
||||
await on_event(
|
||||
events=events,
|
||||
conversation_id=conversation_id,
|
||||
sandbox_info=mock_sandbox,
|
||||
app_conversation_info_service=mock_app_conversation_info_service,
|
||||
event_service=mock_event_service,
|
||||
)
|
||||
|
||||
# Verify stats update was NOT called
|
||||
mock_app_conversation_info_service.update_conversation_statistics.assert_not_called()
|
||||
@@ -152,6 +152,7 @@ class TestExperimentManagerIntegration:
|
||||
llm_base_url=None,
|
||||
llm_api_key=None,
|
||||
confirmation_mode=False,
|
||||
condenser_max_size=None,
|
||||
)
|
||||
|
||||
async def get_secrets(self):
|
||||
|
||||
Reference in New Issue
Block a user