mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
11 Commits
debug-logg
...
fix-async-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4557d98cf9 | ||
|
|
f97117a1fe | ||
|
|
89228a01c3 | ||
|
|
b7b569993e | ||
|
|
102095affb | ||
|
|
b6ce45b474 | ||
|
|
11c87caba4 | ||
|
|
b8a608c45e | ||
|
|
8a446787be | ||
|
|
353124e171 | ||
|
|
e9298c89bd |
205
enterprise/downgrade_migrated_users.py
Normal file
205
enterprise/downgrade_migrated_users.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Downgrade script for migrated users.
|
||||
|
||||
This script identifies users who have been migrated (already_migrated=True)
|
||||
and reverts them back to the pre-migration state.
|
||||
|
||||
Usage:
|
||||
# Dry run - just list the users that would be downgraded
|
||||
python downgrade_migrated_users.py --dry-run
|
||||
|
||||
# Downgrade a specific user by their keycloak_user_id
|
||||
python downgrade_migrated_users.py --user-id <user_id>
|
||||
|
||||
# Downgrade all migrated users (with confirmation)
|
||||
python downgrade_migrated_users.py --all
|
||||
|
||||
# Downgrade all migrated users without confirmation (dangerous!)
|
||||
python downgrade_migrated_users.py --all --no-confirm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
# Add the enterprise directory to the path
|
||||
sys.path.insert(0, '/workspace/project/OpenHands/enterprise')
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from storage.database import session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
|
||||
def get_migrated_users() -> list[str]:
|
||||
"""Get list of keycloak_user_ids for users who have been migrated.
|
||||
|
||||
This includes:
|
||||
1. Users with already_migrated=True in user_settings (migrated users)
|
||||
2. Users in the 'user' table who don't have a user_settings entry (new sign-ups)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Get users from user_settings with already_migrated=True
|
||||
migrated_result = session.execute(
|
||||
select(UserSettings.keycloak_user_id).where(
|
||||
UserSettings.already_migrated.is_(True)
|
||||
)
|
||||
)
|
||||
migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]}
|
||||
|
||||
# Get users from the 'user' table (new sign-ups won't have user_settings)
|
||||
# These are users who signed up after the migration was deployed
|
||||
new_signup_result = session.execute(
|
||||
text("""
|
||||
SELECT CAST(u.id AS VARCHAR)
|
||||
FROM "user" u
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM user_settings us
|
||||
WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR)
|
||||
)
|
||||
""")
|
||||
)
|
||||
new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]}
|
||||
|
||||
# Combine both sets
|
||||
all_users = migrated_users | new_signups
|
||||
return list(all_users)
|
||||
|
||||
|
||||
async def downgrade_user(user_id: str) -> bool:
|
||||
"""Downgrade a single user.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak_user_id to downgrade
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await UserStore.downgrade_user(user_id)
|
||||
if result:
|
||||
print(f'✓ Successfully downgraded user: {user_id}')
|
||||
return True
|
||||
else:
|
||||
print(f'✗ Failed to downgrade user: {user_id}')
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f'✗ Error downgrading user {user_id}: {e}')
|
||||
logger.exception(
|
||||
'downgrade_script:error',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Downgrade migrated users back to pre-migration state'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Just list users that would be downgraded, without making changes',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--user-id',
|
||||
type=str,
|
||||
help='Downgrade a specific user by keycloak_user_id',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--all',
|
||||
action='store_true',
|
||||
help='Downgrade all migrated users',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-confirm',
|
||||
action='store_true',
|
||||
help='Skip confirmation prompt (use with caution!)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get list of migrated users
|
||||
migrated_users = get_migrated_users()
|
||||
print(f'\nFound {len(migrated_users)} migrated user(s).')
|
||||
|
||||
if args.dry_run:
|
||||
print('\n--- DRY RUN MODE ---')
|
||||
print('The following users would be downgraded:')
|
||||
for user_id in migrated_users:
|
||||
print(f' - {user_id}')
|
||||
print('\nNo changes were made.')
|
||||
return
|
||||
|
||||
if args.user_id:
|
||||
# Downgrade a specific user
|
||||
if args.user_id not in migrated_users:
|
||||
print(f'\nUser {args.user_id} is not in the migrated users list.')
|
||||
print('Either the user was not migrated, or the user_id is incorrect.')
|
||||
return
|
||||
|
||||
print(f'\nDowngrading user: {args.user_id}')
|
||||
if not args.no_confirm:
|
||||
confirm = input('Are you sure? (yes/no): ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
success = await downgrade_user(args.user_id)
|
||||
if success:
|
||||
print('\nDowngrade completed successfully.')
|
||||
else:
|
||||
print('\nDowngrade failed. Check logs for details.')
|
||||
sys.exit(1)
|
||||
|
||||
elif args.all:
|
||||
# Downgrade all migrated users
|
||||
if not migrated_users:
|
||||
print('\nNo migrated users to downgrade.')
|
||||
return
|
||||
|
||||
print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).')
|
||||
if not args.no_confirm:
|
||||
print('\nThis will:')
|
||||
print(' - Revert LiteLLM team/user budget settings')
|
||||
print(' - Delete organization entries')
|
||||
print(' - Delete user entries in the new schema')
|
||||
print(' - Reset the already_migrated flag')
|
||||
print('\nUsers to downgrade:')
|
||||
for user_id in migrated_users[:10]: # Show first 10
|
||||
print(f' - {user_id}')
|
||||
if len(migrated_users) > 10:
|
||||
print(f' ... and {len(migrated_users) - 10} more')
|
||||
|
||||
confirm = input('\nType "yes" to proceed: ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
print('\nStarting downgrade...\n')
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for user_id in migrated_users:
|
||||
success = await downgrade_user(user_id)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
print('\n--- Summary ---')
|
||||
print(f'Successful: {success_count}')
|
||||
print(f'Failed: {fail_count}')
|
||||
|
||||
if fail_count > 0:
|
||||
sys.exit(1)
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
print('\nPlease specify --dry-run, --user-id, or --all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
24
enterprise/poetry.lock
generated
24
enterprise/poetry.lock
generated
@@ -6102,14 +6102,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.9.1"
|
||||
version = "1.10.0"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.9.1-py3-none-any.whl", hash = "sha256:ea1457760505b9ebfe6aabea08dedd010ce93aeb93edb450f00e25a0d056a723"},
|
||||
{file = "openhands_agent_server-1.9.1.tar.gz", hash = "sha256:d92a29a9d5aa94207519a5f8daad7c0a3d6641d5cba9f763f25aa4e85713fa0f"},
|
||||
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
|
||||
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6168,9 +6168,9 @@ memory-profiler = ">=0.61"
|
||||
numpy = "*"
|
||||
openai = "2.8"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = "1.9.1"
|
||||
openhands-sdk = "1.9.1"
|
||||
openhands-tools = "1.9.1"
|
||||
openhands-agent-server = "1.10"
|
||||
openhands-sdk = "1.10"
|
||||
openhands-tools = "1.10"
|
||||
opentelemetry-api = ">=1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
pathspec = ">=0.12.1"
|
||||
@@ -6225,14 +6225,14 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.9.1"
|
||||
version = "1.10.0"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.9.1-py3-none-any.whl", hash = "sha256:0e732dfe0d91289536ea0410db9554d5a5b0326f60e547ea7a9d8ddab5fe93e4"},
|
||||
{file = "openhands_sdk-1.9.1.tar.gz", hash = "sha256:c6ba33f85efa4c2ec63eb1040cbe82839662bcbcf323654ed071a9ad38ce7994"},
|
||||
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
|
||||
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6253,14 +6253,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.9.1"
|
||||
version = "1.10.0"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.9.1-py3-none-any.whl", hash = "sha256:411819657e00ffac5d5b1ba9adc6eb65a0a17cbefb5e3e1a34bb132ff61c59f2"},
|
||||
{file = "openhands_tools-1.9.1.tar.gz", hash = "sha256:331608994cce22b662038a2fed0bf7d2c1bb8dc27b1fc0a12a646e9bd76e0843"},
|
||||
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
|
||||
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -14,7 +14,6 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -108,13 +107,10 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
f'[GitHub] Sent summary instruction to conversation {conversation_id} {summary_event}'
|
||||
)
|
||||
|
||||
# Update the processor state
|
||||
# Update the processor state - the outer session will commit this
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -130,14 +126,15 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
|
||||
logger.info(f'[GitHub] Summary sent for conversation {conversation_id}')
|
||||
|
||||
# Mark callback as completed status
|
||||
# Mark callback as completed status - the outer session will commit this
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f'[GitHub] Error processing conversation callback: {str(e)}'
|
||||
)
|
||||
# Mark callback as error to prevent infinite re-invocation
|
||||
# The outer session will commit this
|
||||
callback.status = CallbackStatus.ERROR
|
||||
callback.updated_at = datetime.now()
|
||||
|
||||
@@ -98,6 +98,29 @@ def decrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
return get_fernet().decrypt(b64decode(value.encode())).decode()
|
||||
|
||||
|
||||
def encrypt_legacy_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_legacy_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = encrypt_legacy_value(value)
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return b64encode(
|
||||
get_fernet().encrypt(value.get_secret_value().encode())
|
||||
).decode()
|
||||
else:
|
||||
return b64encode(get_fernet().encrypt(value.encode())).decode()
|
||||
|
||||
|
||||
def get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
|
||||
@@ -141,7 +141,7 @@ class LiteLlmManager:
|
||||
return None
|
||||
credits = max(max_budget - spend, 0.0)
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:create_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
@@ -149,7 +149,7 @@ class LiteLlmManager:
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_user',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
@@ -157,7 +157,7 @@ class LiteLlmManager:
|
||||
client, keycloak_user_id, max_budget=UNLIMITED_BUDGET_SETTING
|
||||
)
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:add_user_to_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
@@ -166,7 +166,7 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
@@ -178,7 +178,7 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_byor_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
@@ -190,7 +190,164 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:end',
|
||||
'LiteLlmManager:migrate_lite_llm_entries:complete',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
async def downgrade_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
"""Downgrade a migrated user's LiteLLM entries back to the pre-migration state.
|
||||
|
||||
This reverses the migrate_entries operation:
|
||||
1. Get the user max budget from their org team in litellm
|
||||
2. Set the max budget in the user in litellm (restore from team)
|
||||
3. Add the user back to the default team in litellm
|
||||
4. Update keys to remove org team association
|
||||
5. Remove the user from their org team in litellm
|
||||
6. Delete the user org team in litellm
|
||||
|
||||
Note: The database changes (already_migrated flag, org/org_member deletion)
|
||||
should be handled separately by the caller.
|
||||
|
||||
Args:
|
||||
org_id: The organization ID (which is also the team_id in litellm)
|
||||
keycloak_user_id: The user's Keycloak ID
|
||||
user_settings: The user's settings object
|
||||
|
||||
Returns:
|
||||
The user_settings if downgrade was successful, None otherwise
|
||||
"""
|
||||
logger.info(
|
||||
'LiteLlmManager:downgrade_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
if not local_deploy:
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
# Step 1: Get the team info to retrieve the budget
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:get_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
team_info = await LiteLlmManager._get_team(client, org_id)
|
||||
if not team_info:
|
||||
logger.error(
|
||||
'LiteLlmManager:downgrade_entries:team_not_found',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get team budget (max_budget) and spend to calculate current credits
|
||||
team_data = team_info.get('team_info', {})
|
||||
max_budget = team_data.get('max_budget', 0.0)
|
||||
spend = team_data.get('spend', 0.0)
|
||||
|
||||
# Get user membership info for budget in team
|
||||
user_membership = await LiteLlmManager._get_user_team_info(
|
||||
client, keycloak_user_id, org_id
|
||||
)
|
||||
if user_membership:
|
||||
# Use user's budget in team if available
|
||||
user_max_budget_in_team = user_membership.get('max_budget_in_team')
|
||||
user_spend_in_team = user_membership.get('spend', 0.0)
|
||||
if user_max_budget_in_team is not None:
|
||||
max_budget = user_max_budget_in_team
|
||||
spend = user_spend_in_team
|
||||
|
||||
# Calculate total budget to restore (credits + spend = max_budget)
|
||||
# We restore the full max_budget that was on the team/user-in-team
|
||||
restored_budget = max_budget if max_budget else 0.0
|
||||
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:budget_info',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
'restored_budget': restored_budget,
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: Update user to set their max_budget back from unlimited
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_user',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_user(
|
||||
client, keycloak_user_id, max_budget=restored_budget, spend=spend
|
||||
)
|
||||
|
||||
# Step 3: Add user back to the default team
|
||||
if LITE_LLM_TEAM_ID:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:add_to_default_team',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'default_team_id': LITE_LLM_TEAM_ID,
|
||||
},
|
||||
)
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, LITE_LLM_TEAM_ID, restored_budget
|
||||
)
|
||||
|
||||
# Step 4: Update keys to remove org team association (set team_id to default)
|
||||
if user_settings.llm_api_key:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key,
|
||||
team_id=LITE_LLM_TEAM_ID,
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_byor_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key_for_byor,
|
||||
team_id=LITE_LLM_TEAM_ID,
|
||||
)
|
||||
|
||||
# Step 5: Remove user from their org team
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:remove_from_org_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._remove_user_from_team(
|
||||
client, keycloak_user_id, org_id
|
||||
)
|
||||
|
||||
# Step 6: Delete the org team
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:delete_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._delete_team(client, org_id)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:downgrade_entries:complete',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
return user_settings
|
||||
@@ -637,6 +794,45 @@ class LiteLlmManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _remove_user_from_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_delete',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
if not response.is_success:
|
||||
if response.status_code == 404:
|
||||
# User not in team, that's fine for downgrade
|
||||
logger.info(
|
||||
'User not in team during removal',
|
||||
extra={'user_id': keycloak_user_id, 'team_id': team_id},
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_removing_litellm_user_from_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
'LiteLlmManager:_remove_user_from_team:user_removed',
|
||||
extra={'user_id': keycloak_user_id, 'team_id': team_id},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _generate_key(
|
||||
client: httpx.AsyncClient,
|
||||
@@ -880,6 +1076,7 @@ class LiteLlmManager:
|
||||
delete_user = staticmethod(with_http_client(_delete_user))
|
||||
delete_team = staticmethod(with_http_client(_delete_team))
|
||||
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
|
||||
remove_user_from_team = staticmethod(with_http_client(_remove_user_from_team))
|
||||
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
|
||||
@@ -17,7 +17,10 @@ from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.encrypt_utils import decrypt_legacy_model
|
||||
from storage.encrypt_utils import (
|
||||
decrypt_legacy_model,
|
||||
encrypt_legacy_value,
|
||||
)
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
@@ -159,7 +162,7 @@ class UserStore:
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_litellm_migrate_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
@@ -169,7 +172,7 @@ class UserStore:
|
||||
decrypted_user_settings,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_litellm_migrate_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
@@ -180,12 +183,12 @@ class UserStore:
|
||||
# avoids circular reference. This migrate method is temprorary until all users are migrated.
|
||||
from integrations.stripe_service import migrate_customer
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await migrate_customer(session, user_id, org)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
@@ -217,12 +220,12 @@ class UserStore:
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_get_role_by_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
role = await RoleStore.get_role_by_name_async('owner')
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_get_role_by_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
@@ -238,7 +241,6 @@ class UserStore:
|
||||
if not custom_settings:
|
||||
del org_member_kwargs['llm_model']
|
||||
del org_member_kwargs['llm_base_url']
|
||||
del org_member_kwargs['llm_api_key_for_byor']
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
@@ -253,7 +255,7 @@ class UserStore:
|
||||
user_settings.already_migrated = True
|
||||
session.merge(user_settings)
|
||||
session.flush()
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'user_store:migrate_user:session_flush_complete',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
@@ -324,12 +326,262 @@ class UserStore:
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
logger.info(
|
||||
logger.debug(
|
||||
'user_store:migrate_user:session_committed',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def downgrade_user(user_id: str) -> UserSettings | None:
|
||||
"""Downgrade a migrated user back to the pre-migration state.
|
||||
|
||||
This reverses the migrate_user operation:
|
||||
1. Get the user's settings from user_settings table (migrated users) or
|
||||
create new user_settings from org_members table (new sign-ups)
|
||||
2. Call LiteLlmManager.downgrade_entries to revert LiteLLM state
|
||||
3. Copy user_id from conversation_metadata_saas to conversation_metadata
|
||||
4. Delete conversation_metadata_saas entries
|
||||
5. Reset org_id columns in related tables (stripe_customers, slack_users, etc.)
|
||||
6. Delete the org_member and org entries
|
||||
7. Delete the user entry
|
||||
8. Set already_migrated=False on user_settings
|
||||
|
||||
For new sign-ups (users who registered after migration was deployed),
|
||||
there won't be an existing user_settings entry. In this case, we fall back
|
||||
to the org_members table to get the user's API keys and settings, and create
|
||||
a new user_settings entry for them.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to downgrade
|
||||
|
||||
Returns:
|
||||
The user_settings if downgrade was successful, None otherwise.
|
||||
Returns None if the org has multiple members (not a personal org).
|
||||
"""
|
||||
logger.info(
|
||||
'user_store:downgrade_user:start',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
# Get the user and their org_member
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:user_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the user's personal org (org_id == user_id)
|
||||
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:org_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the user_settings (for migrated users)
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# For new sign-ups after migration, user_settings won't exist
|
||||
# Fall back to getting data from org_members
|
||||
is_new_signup = False
|
||||
if not user_settings:
|
||||
logger.info(
|
||||
'user_store:downgrade_user:user_settings_not_found_checking_org_members',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Get org_members for this org - should only be one for personal orgs
|
||||
org_members = (
|
||||
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
|
||||
)
|
||||
|
||||
if len(org_members) != 1:
|
||||
logger.error(
|
||||
'user_store:downgrade_user:unexpected_org_members_count',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'org_members_count': len(org_members),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
org_member = org_members[0]
|
||||
is_new_signup = True
|
||||
|
||||
# Create a new user_settings entry from OrgMember, User, and Org data
|
||||
# This is needed for new sign-ups who don't have user_settings
|
||||
user_settings = UserStore._create_user_settings_from_entities(
|
||||
user_id, org_member, user, org
|
||||
)
|
||||
session.add(user_settings)
|
||||
session.flush()
|
||||
logger.info(
|
||||
'user_store:downgrade_user:created_user_settings_from_org_member',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Call LiteLLM downgrade
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
logger.debug(
|
||||
'user_store:downgrade_user:calling_litellm_downgrade_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Get the API keys for LiteLLM downgrade
|
||||
if is_new_signup:
|
||||
# For new signups, we already have decrypted values in user_settings
|
||||
decrypted_user_settings = user_settings
|
||||
else:
|
||||
# For migrated users, decrypt the legacy model
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
|
||||
await LiteLlmManager.downgrade_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
logger.debug(
|
||||
'user_store:downgrade_user:done_litellm_downgrade_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Step 3: Copy user_id from conversation_metadata_saas to conversation_metadata
|
||||
# This ensures any conversations created after migration have their user_id
|
||||
# preserved in the original table before we delete the saas entries
|
||||
session.execute(
|
||||
text("""
|
||||
UPDATE conversation_metadata
|
||||
SET user_id = :user_id
|
||||
WHERE conversation_id IN (
|
||||
SELECT conversation_id
|
||||
FROM conversation_metadata_saas
|
||||
WHERE user_id = :user_uuid
|
||||
)
|
||||
"""),
|
||||
{'user_id': user_id, 'user_uuid': user_uuid},
|
||||
)
|
||||
|
||||
# Step 4: Delete conversation_metadata_saas entries
|
||||
session.execute(
|
||||
text('DELETE FROM conversation_metadata_saas WHERE user_id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 5: Reset org_id columns in related tables
|
||||
# Reset stripe_customers
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset slack_users
|
||||
session.execute(
|
||||
text('UPDATE slack_users SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset slack_conversation
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset api_keys
|
||||
session.execute(
|
||||
text('UPDATE api_keys SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset custom_secrets
|
||||
session.execute(
|
||||
text('UPDATE custom_secrets SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset billing_sessions
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 6: Delete org_member entries for this org
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 7: Delete the user entry
|
||||
session.execute(
|
||||
text('DELETE FROM "user" WHERE id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Delete the org entry
|
||||
session.execute(
|
||||
text('DELETE FROM org WHERE id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 8: Set already_migrated=False on user_settings and encrypt fields
|
||||
user_settings.already_migrated = False
|
||||
|
||||
# Re-encrypt the sensitive fields before storing in the DB
|
||||
encrypt_keys = [
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
]
|
||||
for key in encrypt_keys:
|
||||
value = getattr(user_settings, key, None)
|
||||
if value is not None:
|
||||
setattr(user_settings, key, encrypt_legacy_value(value))
|
||||
|
||||
session.merge(user_settings)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_store:downgrade_user:complete',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (sync version).
|
||||
@@ -520,6 +772,96 @@ class UserStore:
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def _create_user_settings_from_entities(
|
||||
user_id: str, org_member: OrgMember, user: User, org: Org
|
||||
) -> UserSettings:
|
||||
"""Create UserSettings from OrgMember, User, and Org data.
|
||||
|
||||
Uses OrgMember values first. If an OrgMember field is None and there's
|
||||
a corresponding "default_" field in Org, use the Org value.
|
||||
Also pulls relevant fields from User.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID
|
||||
org_member: The OrgMember entity
|
||||
user: The User entity
|
||||
org: The Org entity
|
||||
|
||||
Returns:
|
||||
A new UserSettings object populated from the entities
|
||||
"""
|
||||
# Mapping from OrgMember fields to corresponding Org "default_" fields
|
||||
org_member_to_org_default = {
|
||||
'llm_model': 'default_llm_model',
|
||||
'llm_base_url': 'default_llm_base_url',
|
||||
'max_iterations': 'default_max_iterations',
|
||||
}
|
||||
|
||||
def get_value_with_org_fallback(field_name: str, org_member_value):
|
||||
"""Get value from OrgMember, falling back to Org default if None."""
|
||||
if org_member_value is not None:
|
||||
return org_member_value
|
||||
org_default_field = org_member_to_org_default.get(field_name)
|
||||
if org_default_field and hasattr(org, org_default_field):
|
||||
return getattr(org, org_default_field)
|
||||
return None
|
||||
|
||||
# Get values from OrgMember with Org fallback for fields with default_ prefix
|
||||
llm_model = get_value_with_org_fallback('llm_model', org_member.llm_model)
|
||||
llm_base_url = get_value_with_org_fallback(
|
||||
'llm_base_url', org_member.llm_base_url
|
||||
)
|
||||
max_iterations = get_value_with_org_fallback(
|
||||
'max_iterations', org_member.max_iterations
|
||||
)
|
||||
|
||||
return UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
# OrgMember fields
|
||||
llm_api_key=org_member.llm_api_key.get_secret_value()
|
||||
if org_member.llm_api_key
|
||||
else None,
|
||||
llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value()
|
||||
if org_member.llm_api_key_for_byor
|
||||
else None,
|
||||
llm_model=llm_model,
|
||||
llm_base_url=llm_base_url,
|
||||
max_iterations=max_iterations,
|
||||
# User fields
|
||||
accepted_tos=user.accepted_tos,
|
||||
enable_sound_notifications=user.enable_sound_notifications,
|
||||
language=user.language,
|
||||
user_consents_to_analytics=user.user_consents_to_analytics,
|
||||
email=user.email,
|
||||
email_verified=user.email_verified,
|
||||
git_user_name=user.git_user_name,
|
||||
git_user_email=user.git_user_email,
|
||||
# Org fields
|
||||
agent=org.agent,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
user_version=org.org_version,
|
||||
mcp_config=org.mcp_config,
|
||||
search_api_key=org.search_api_key.get_secret_value()
|
||||
if org.search_api_key
|
||||
else None,
|
||||
sandbox_api_key=org.sandbox_api_key.get_secret_value()
|
||||
if org.sandbox_api_key
|
||||
else None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
condenser_max_size=org.condenser_max_size,
|
||||
already_migrated=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _has_custom_settings(
|
||||
user_settings: UserSettings, old_user_version: int | None
|
||||
|
||||
@@ -1126,3 +1126,174 @@ class TestLiteLlmManager:
|
||||
'http://test.url/team/delete',
|
||||
json={'team_ids': [team_id]},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_from_team_successful(self):
|
||||
"""
|
||||
GIVEN: Valid user_id and team_id
|
||||
WHEN: _remove_user_from_team is called
|
||||
THEN: HTTP POST is made to remove user from team
|
||||
"""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
|
||||
with (
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
await LiteLlmManager._remove_user_from_team(
|
||||
mock_client, 'test-user-id', 'test-team-id'
|
||||
)
|
||||
|
||||
mock_client.post.assert_called_once_with(
|
||||
'http://test.url/team/member_delete',
|
||||
json={
|
||||
'team_id': 'test-team-id',
|
||||
'user_id': 'test-user-id',
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_from_team_not_found(self):
|
||||
"""
|
||||
GIVEN: User not in team
|
||||
WHEN: _remove_user_from_team is called
|
||||
THEN: 404 response is handled gracefully without raising
|
||||
"""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = False
|
||||
mock_response.status_code = 404
|
||||
mock_response.text = 'User not found in team'
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with (
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Should not raise an exception
|
||||
await LiteLlmManager._remove_user_from_team(
|
||||
mock_client, 'test-user-id', 'test-team-id'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_missing_config(self, mock_user_settings):
|
||||
"""Test downgrade_entries when LiteLLM config is missing."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_team_not_found(self, mock_user_settings):
|
||||
"""Test downgrade_entries when team is not found."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_team', new_callable=AsyncMock
|
||||
) as mock_get_team:
|
||||
mock_get_team.return_value = None
|
||||
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_successful(self, mock_user_settings):
|
||||
"""Test successful downgrade_entries operation."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_team_info_response = MagicMock()
|
||||
mock_team_info_response.is_success = True
|
||||
mock_team_info_response.status_code = 200
|
||||
mock_team_info_response.json.return_value = {
|
||||
'team_info': {
|
||||
'max_budget': 100.0,
|
||||
'spend': 20.0,
|
||||
},
|
||||
'team_memberships': [
|
||||
{
|
||||
'user_id': 'test-user-id',
|
||||
'team_id': 'test-org-id',
|
||||
'max_budget_in_team': 100.0,
|
||||
'spend': 20.0,
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_team_info_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'default-team'
|
||||
):
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_team_info_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# downgrade_entries returns the user_settings
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
|
||||
# Verify downgrade steps were called:
|
||||
# 1. get_team (GET)
|
||||
# 2. get_user_team_info (GET via _get_team)
|
||||
# 3. update_user (POST)
|
||||
# 4. add_user_to_team (POST)
|
||||
# 5. update_key (POST)
|
||||
# 6. remove_user_from_team (POST)
|
||||
# 7. delete_team (POST)
|
||||
assert mock_client.get.call_count >= 1
|
||||
assert mock_client.post.call_count >= 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_local_deployment(self, mock_user_settings):
|
||||
"""Test downgrade_entries in local deployment mode (skips LiteLLM calls)."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': 'true'}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# In local deployment, should return user_settings without
|
||||
# making any LiteLLM calls
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
|
||||
@@ -1,9 +1,36 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock the database module before importing RoleStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
def test_get_role_by_id(session_maker):
|
||||
@@ -81,3 +108,63 @@ def test_create_role(session_maker):
|
||||
assert role.name == 'moderator'
|
||||
assert role.rank == 2
|
||||
assert role.id is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_role_by_name_async_with_session(async_session_maker):
|
||||
"""Test getting role by name asynchronously with an explicit session."""
|
||||
# Create a test role
|
||||
async with async_session_maker() as session:
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
await session.refresh(role)
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval with explicit session
|
||||
async with async_session_maker() as session:
|
||||
retrieved_role = await RoleStore.get_role_by_name_async(
|
||||
'admin', session=session
|
||||
)
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
assert retrieved_role.rank == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_role_by_name_async_without_session(async_session_maker):
|
||||
"""Test getting role by name asynchronously using internal session maker."""
|
||||
# Create a test role
|
||||
async with async_session_maker() as session:
|
||||
role = Role(name='editor', rank=2)
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
await session.refresh(role)
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval without explicit session (using patched a_session_maker)
|
||||
with patch('storage.role_store.a_session_maker', async_session_maker):
|
||||
retrieved_role = await RoleStore.get_role_by_name_async('editor')
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'editor'
|
||||
assert retrieved_role.rank == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_role_by_name_async_not_found_with_session(async_session_maker):
|
||||
"""Test getting role by name when it doesn't exist (with explicit session)."""
|
||||
async with async_session_maker() as session:
|
||||
retrieved_role = await RoleStore.get_role_by_name_async(
|
||||
'nonexistent', session=session
|
||||
)
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_role_by_name_async_not_found_without_session(async_session_maker):
|
||||
"""Test getting role by name when it doesn't exist (without explicit session)."""
|
||||
with patch('storage.role_store.a_session_maker', async_session_maker):
|
||||
retrieved_role = await RoleStore.get_role_by_name_async('nonexistent')
|
||||
assert retrieved_role is None
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, vi, beforeEach, it } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { GitBranchDropdown } from "../../../../src/components/features/home/git-branch-dropdown/git-branch-dropdown";
|
||||
import { Branch } from "#/types/git";
|
||||
|
||||
// Mock the branch data hook
|
||||
const mockUseBranchData = vi.fn();
|
||||
vi.mock("#/hooks/query/use-branch-data", () => ({
|
||||
useBranchData: (...args: unknown[]) => mockUseBranchData(...args),
|
||||
}));
|
||||
|
||||
const MOCK_BRANCHES: Branch[] = [
|
||||
{ name: "main", commit_sha: "abc123", protected: true },
|
||||
{ name: "develop", commit_sha: "def456", protected: false },
|
||||
{ name: "feature/test", commit_sha: "ghi789", protected: false },
|
||||
];
|
||||
|
||||
const mockOnBranchSelect = vi.fn();
|
||||
|
||||
const renderDropdown = (
|
||||
props: Partial<Parameters<typeof GitBranchDropdown>[0]> = {},
|
||||
) => {
|
||||
// Default mock return value
|
||||
mockUseBranchData.mockReturnValue({
|
||||
branches: MOCK_BRANCHES,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
isFetchingNextPage: false,
|
||||
isSearchLoading: false,
|
||||
});
|
||||
|
||||
return render(
|
||||
<GitBranchDropdown
|
||||
repository="user/repo"
|
||||
provider="github"
|
||||
selectedBranch={null}
|
||||
onBranchSelect={mockOnBranchSelect}
|
||||
// eslint-disable-next-line react/jsx-props-no-spreading
|
||||
{...props}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
describe("GitBranchDropdown", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("dropdown behavior", () => {
|
||||
it("should open dropdown when input is clicked", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-branch-dropdown-input");
|
||||
await userEvent.click(input);
|
||||
|
||||
// Dropdown should be open (menu should be visible)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-branch-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should keep dropdown open when clicking input while already open", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-branch-dropdown-input");
|
||||
|
||||
// First click - open dropdown
|
||||
await userEvent.click(input);
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-branch-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Second click on input - should stay open (not close)
|
||||
await userEvent.click(input);
|
||||
|
||||
// Dropdown should still be open
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-branch-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should preserve typed text when clicking input while typing", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId(
|
||||
"git-branch-dropdown-input",
|
||||
) as HTMLInputElement;
|
||||
|
||||
// Click to open and type
|
||||
await userEvent.click(input);
|
||||
await userEvent.type(input, "feat");
|
||||
|
||||
expect(input.value).toBe("feat");
|
||||
|
||||
// Click on input again (should not reset text)
|
||||
await userEvent.click(input);
|
||||
|
||||
// Text should be preserved
|
||||
expect(input.value).toBe("feat");
|
||||
});
|
||||
});
|
||||
|
||||
describe("cursor position preservation", () => {
|
||||
it("should allow editing in the middle of input text", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId(
|
||||
"git-branch-dropdown-input",
|
||||
) as HTMLInputElement;
|
||||
|
||||
// Click and type initial text
|
||||
await userEvent.click(input);
|
||||
await userEvent.type(input, "hello");
|
||||
|
||||
expect(input.value).toBe("hello");
|
||||
|
||||
// Move cursor to position 2 and type
|
||||
input.setSelectionRange(2, 2);
|
||||
await userEvent.type(input, "X");
|
||||
|
||||
// The character should be inserted (exact position may vary based on browser behavior)
|
||||
expect(input.value).toContain("X");
|
||||
});
|
||||
});
|
||||
|
||||
describe("input synchronization", () => {
|
||||
it("should show selected branch name in input when provided", async () => {
|
||||
const selectedBranch = MOCK_BRANCHES[0];
|
||||
renderDropdown({ selectedBranch });
|
||||
|
||||
const input = screen.getByTestId(
|
||||
"git-branch-dropdown-input",
|
||||
) as HTMLInputElement;
|
||||
|
||||
await waitFor(() => {
|
||||
expect(input.value).toBe(selectedBranch.name);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("branch selection", () => {
|
||||
it("should call onBranchSelect when a branch is selected", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-branch-dropdown-input");
|
||||
await userEvent.click(input);
|
||||
|
||||
// Wait for dropdown to open and show branches
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("main")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click on a branch
|
||||
await userEvent.click(screen.getByText("develop"));
|
||||
|
||||
expect(mockOnBranchSelect).toHaveBeenCalledWith(MOCK_BRANCHES[1]);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,234 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, vi, beforeEach, it } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { GitRepoDropdown } from "../../../../src/components/features/home/git-repo-dropdown/git-repo-dropdown";
|
||||
import { GitRepository } from "#/types/git";
|
||||
|
||||
// Mock the repository data hook
|
||||
const mockUseRepositoryData = vi.fn();
|
||||
vi.mock(
|
||||
"#/components/features/home/git-repo-dropdown/use-repository-data",
|
||||
() => ({
|
||||
useRepositoryData: (...args: unknown[]) => mockUseRepositoryData(...args),
|
||||
}),
|
||||
);
|
||||
|
||||
// Mock the URL search hook
|
||||
const mockUseUrlSearch = vi.fn();
|
||||
vi.mock("#/components/features/home/git-repo-dropdown/use-url-search", () => ({
|
||||
useUrlSearch: (...args: unknown[]) => mockUseUrlSearch(...args),
|
||||
}));
|
||||
|
||||
// Mock useConfig
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({ data: null }),
|
||||
}));
|
||||
|
||||
// Mock useHomeStore
|
||||
vi.mock("#/stores/home-store", () => ({
|
||||
useHomeStore: () => ({ recentRepositories: [] }),
|
||||
}));
|
||||
|
||||
const MOCK_REPOSITORIES: GitRepository[] = [
|
||||
{
|
||||
id: "1",
|
||||
full_name: "user/repo-one",
|
||||
git_provider: "github",
|
||||
is_public: true,
|
||||
},
|
||||
{
|
||||
id: "2",
|
||||
full_name: "user/repo-two",
|
||||
git_provider: "github",
|
||||
is_public: true,
|
||||
},
|
||||
{
|
||||
id: "3",
|
||||
full_name: "org/feature-repo",
|
||||
git_provider: "github",
|
||||
is_public: false,
|
||||
},
|
||||
];
|
||||
|
||||
const mockOnChange = vi.fn();
|
||||
|
||||
const setupDefaultMocks = (
|
||||
repositoryDataOverrides: Partial<
|
||||
ReturnType<typeof mockUseRepositoryData>
|
||||
> = {},
|
||||
) => {
|
||||
mockUseRepositoryData.mockReturnValue({
|
||||
repositories: MOCK_REPOSITORIES,
|
||||
selectedRepository: null,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
isFetchingNextPage: false,
|
||||
isSearchLoading: false,
|
||||
...repositoryDataOverrides,
|
||||
});
|
||||
|
||||
mockUseUrlSearch.mockReturnValue({
|
||||
urlSearchResults: [],
|
||||
isUrlSearchLoading: false,
|
||||
});
|
||||
};
|
||||
|
||||
const renderDropdown = (
|
||||
props: Partial<Parameters<typeof GitRepoDropdown>[0]> = {},
|
||||
repositoryDataOverrides: Partial<
|
||||
ReturnType<typeof mockUseRepositoryData>
|
||||
> = {},
|
||||
) => {
|
||||
// Set up mocks with optional overrides
|
||||
setupDefaultMocks(repositoryDataOverrides);
|
||||
|
||||
return render(
|
||||
<GitRepoDropdown
|
||||
provider="github"
|
||||
onChange={mockOnChange}
|
||||
// eslint-disable-next-line react/jsx-props-no-spreading
|
||||
{...props}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
describe("GitRepoDropdown", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("dropdown behavior", () => {
|
||||
it("should open dropdown when input is clicked", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown");
|
||||
await userEvent.click(input);
|
||||
|
||||
// Dropdown should be open (menu should be visible)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-repo-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should keep dropdown open when clicking input while already open", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown");
|
||||
|
||||
// First click - open dropdown
|
||||
await userEvent.click(input);
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-repo-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Second click on input - should stay open (not close)
|
||||
await userEvent.click(input);
|
||||
|
||||
// Dropdown should still be open
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-repo-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should preserve typed text when clicking input while typing", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
|
||||
|
||||
// Click to open and type
|
||||
await userEvent.click(input);
|
||||
await userEvent.type(input, "repo");
|
||||
|
||||
expect(input.value).toBe("repo");
|
||||
|
||||
// Click on input again (should not reset text)
|
||||
await userEvent.click(input);
|
||||
|
||||
// Text should be preserved
|
||||
expect(input.value).toBe("repo");
|
||||
});
|
||||
});
|
||||
|
||||
describe("cursor position preservation", () => {
|
||||
it("should allow editing in the middle of input text", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
|
||||
|
||||
// Click and type initial text
|
||||
await userEvent.click(input);
|
||||
await userEvent.type(input, "hello");
|
||||
|
||||
expect(input.value).toBe("hello");
|
||||
|
||||
// Move cursor to position 2 and type
|
||||
input.setSelectionRange(2, 2);
|
||||
await userEvent.type(input, "X");
|
||||
|
||||
// The character should be inserted (exact position may vary based on browser behavior)
|
||||
expect(input.value).toContain("X");
|
||||
});
|
||||
});
|
||||
|
||||
describe("input synchronization", () => {
|
||||
it("should show selected repository name in input when provided", async () => {
|
||||
const selectedRepository = MOCK_REPOSITORIES[0];
|
||||
|
||||
renderDropdown(
|
||||
{ value: selectedRepository.full_name },
|
||||
{ selectedRepository },
|
||||
);
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
|
||||
|
||||
await waitFor(() => {
|
||||
expect(input.value).toBe(selectedRepository.full_name);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("repository selection", () => {
|
||||
it("should call onChange when a repository is selected", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown");
|
||||
await userEvent.click(input);
|
||||
|
||||
// Wait for dropdown to open and show repositories
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("user/repo-one")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click on a repository
|
||||
await userEvent.click(screen.getByText("user/repo-two"));
|
||||
|
||||
expect(mockOnChange).toHaveBeenCalledWith(MOCK_REPOSITORIES[1]);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -87,16 +87,6 @@ export function GitBranchDropdown({
|
||||
[onBranchSelect],
|
||||
);
|
||||
|
||||
// Handle input value change
|
||||
const handleInputValueChange = useCallback(
|
||||
({ inputValue: newInputValue }: { inputValue?: string }) => {
|
||||
if (newInputValue !== undefined) {
|
||||
setInputValue(newInputValue);
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
// Handle menu scroll for infinite loading
|
||||
const handleMenuScroll = useCallback(
|
||||
(event: React.UIEvent<HTMLUListElement>) => {
|
||||
@@ -128,8 +118,14 @@ export function GitBranchDropdown({
|
||||
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
|
||||
handleBranchSelect(newSelectedItem || null);
|
||||
},
|
||||
onInputValueChange: handleInputValueChange,
|
||||
inputValue,
|
||||
// Override Downshift's default input-click behavior to avoid closing/reopening
|
||||
// the menu, which would reset scroll position and break search continuity.
|
||||
stateReducer: (state, actionAndChanges) =>
|
||||
actionAndChanges.type === useCombobox.stateChangeTypes.InputClick &&
|
||||
state.isOpen
|
||||
? { ...actionAndChanges.changes, isOpen: true }
|
||||
: actionAndChanges.changes,
|
||||
});
|
||||
|
||||
// Reset branch selection when repository changes
|
||||
@@ -176,12 +172,12 @@ export function GitBranchDropdown({
|
||||
|
||||
// Initialize input value when selectedBranch changes (but not when user is typing)
|
||||
useEffect(() => {
|
||||
if (selectedBranch && !isOpen && inputValue !== selectedBranch.name) {
|
||||
if (selectedBranch && !isOpen) {
|
||||
setInputValue(selectedBranch.name);
|
||||
} else if (!selectedBranch && !isOpen && inputValue) {
|
||||
} else if (!selectedBranch && !isOpen) {
|
||||
setInputValue("");
|
||||
}
|
||||
}, [selectedBranch, isOpen, inputValue]);
|
||||
}, [selectedBranch, isOpen]);
|
||||
|
||||
const isLoadingState = isLoading || isSearchLoading || isFetchingNextPage;
|
||||
|
||||
@@ -207,6 +203,10 @@ export function GitBranchDropdown({
|
||||
"disabled:bg-[#363636] disabled:cursor-not-allowed disabled:opacity-60",
|
||||
"pl-7 pr-16 text-sm font-normal leading-5", // Space for clear and toggle buttons
|
||||
),
|
||||
// Direct onChange for cursor position preservation
|
||||
onChange: (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setInputValue(e.target.value);
|
||||
},
|
||||
})}
|
||||
data-testid="git-branch-dropdown-input"
|
||||
/>
|
||||
|
||||
@@ -184,14 +184,6 @@ export function GitRepoDropdown({
|
||||
setInputValue("");
|
||||
}, [handleSelectionChange]);
|
||||
|
||||
// Handle input value change
|
||||
const handleInputValueChange = useCallback(
|
||||
({ inputValue: newInputValue }: { inputValue?: string }) => {
|
||||
setInputValue(newInputValue || "");
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
// Handle scroll to bottom for pagination
|
||||
const handleMenuScroll = useCallback(
|
||||
(event: React.UIEvent<HTMLUListElement>) => {
|
||||
@@ -220,8 +212,14 @@ export function GitRepoDropdown({
|
||||
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
|
||||
handleSelectionChange(newSelectedItem);
|
||||
},
|
||||
onInputValueChange: handleInputValueChange,
|
||||
inputValue,
|
||||
// Override Downshift's default input-click behavior to avoid closing/reopening
|
||||
// the menu, which would reset scroll position and break search continuity.
|
||||
stateReducer: (state, actionAndChanges) =>
|
||||
actionAndChanges.type === useCombobox.stateChangeTypes.InputClick &&
|
||||
state.isOpen
|
||||
? { ...actionAndChanges.changes, isOpen: true }
|
||||
: actionAndChanges.changes,
|
||||
});
|
||||
|
||||
// Sync localSelectedItem with external value prop
|
||||
@@ -237,6 +235,8 @@ export function GitRepoDropdown({
|
||||
useEffect(() => {
|
||||
if (selectedRepository && !isOpen) {
|
||||
setInputValue(selectedRepository.full_name);
|
||||
} else if (!selectedRepository && !isOpen) {
|
||||
setInputValue("");
|
||||
}
|
||||
}, [selectedRepository, isOpen]);
|
||||
|
||||
@@ -335,6 +335,10 @@ export function GitRepoDropdown({
|
||||
"disabled:bg-[#363636] disabled:cursor-not-allowed disabled:opacity-60",
|
||||
"pl-7 pr-16 text-sm font-normal leading-5", // Space for clear and toggle buttons
|
||||
),
|
||||
// Direct onChange for cursor position preservation
|
||||
onChange: (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setInputValue(e.target.value);
|
||||
},
|
||||
})}
|
||||
data-testid="git-repo-dropdown"
|
||||
/>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -14,6 +14,7 @@ from openhands.app_server.sandbox.sandbox_models import SandboxStatus
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.conversation.state import ConversationExecutionStatus
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.plugin import PluginSource
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@@ -24,6 +25,45 @@ class AgentType(Enum):
|
||||
PLAN = 'plan'
|
||||
|
||||
|
||||
class PluginSpec(PluginSource):
|
||||
"""Specification for loading a plugin into a conversation.
|
||||
|
||||
Extends SDK's PluginSource with user-provided plugin configuration parameters.
|
||||
Inherits source, ref, and repo_path fields along with their validation.
|
||||
"""
|
||||
|
||||
parameters: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description='User-provided values for plugin input parameters',
|
||||
)
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Extract a friendly display name from the plugin source.
|
||||
|
||||
Examples:
|
||||
- 'github:owner/repo' -> 'repo'
|
||||
- 'https://github.com/owner/repo.git' -> 'repo.git'
|
||||
- '/local/path' -> 'path'
|
||||
"""
|
||||
return self.source.split('/')[-1] if '/' in self.source else self.source
|
||||
|
||||
def format_params_as_text(self, indent: str = '') -> str | None:
|
||||
"""Format parameters as a readable text block for display.
|
||||
|
||||
Args:
|
||||
indent: Optional prefix to add before each parameter line.
|
||||
|
||||
Returns:
|
||||
Formatted parameters string, or None if no parameters.
|
||||
"""
|
||||
if not self.parameters:
|
||||
return None
|
||||
return '\n'.join(
|
||||
f'{indent}- {key}: {value}' for key, value in self.parameters.items()
|
||||
)
|
||||
|
||||
|
||||
class AppConversationInfo(BaseModel):
|
||||
"""Conversation info which does not contain status."""
|
||||
|
||||
@@ -118,6 +158,15 @@ class AppConversationStartRequest(OpenHandsModel):
|
||||
|
||||
public: bool | None = None
|
||||
|
||||
# Plugin parameters - for loading remote plugins into the conversation
|
||||
plugins: list[PluginSpec] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
'List of plugins to load for this conversation. Plugins are loaded '
|
||||
'and their skills/MCP config are merged into the agent.'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class AppConversationUpdateRequest(BaseModel):
|
||||
public: bool
|
||||
@@ -147,7 +196,8 @@ class AppConversationStartTask(OpenHandsModel):
|
||||
|
||||
Because starting an app conversation can be slow (And can involve starting a sandbox),
|
||||
we kick off a background task for it. Once the conversation is started, the app_conversation_id
|
||||
is populated."""
|
||||
is populated.
|
||||
"""
|
||||
|
||||
id: OpenHandsUUID = Field(default_factory=uuid4)
|
||||
created_by_user_id: str | None
|
||||
|
||||
@@ -621,7 +621,6 @@ async def _stream_app_conversation_start(
|
||||
user_context: UserContext,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream a json list, item by item."""
|
||||
|
||||
# Because the original dependencies are closed after the method returns, we need
|
||||
# a new dependency context which will continue intil the stream finishes.
|
||||
state = InjectorState()
|
||||
|
||||
@@ -105,7 +105,8 @@ class AppConversationService(ABC):
|
||||
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||
) -> AppConversation | None:
|
||||
"""Update an app conversation and return it. Return None if the conversation
|
||||
did not exist."""
|
||||
did not exist.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
||||
|
||||
@@ -51,7 +51,8 @@ PRE_COMMIT_LOCAL = '.git/hooks/pre-commit.local'
|
||||
class AppConversationServiceBase(AppConversationService, ABC):
|
||||
"""App Conversation service which adds git specific functionality.
|
||||
|
||||
Sets up repositories and installs hooks"""
|
||||
Sets up repositories and installs hooks
|
||||
"""
|
||||
|
||||
init_git_in_empty_workspace: bool
|
||||
user_context: UserContext
|
||||
@@ -484,7 +485,6 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
security_analyzer_str: String value from settings
|
||||
httpx_client: HTTP client for making API requests
|
||||
"""
|
||||
|
||||
if session_api_key is None:
|
||||
return
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartTask,
|
||||
AppConversationStartTaskStatus,
|
||||
AppConversationUpdateRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_service import (
|
||||
AppConversationService,
|
||||
@@ -79,6 +80,7 @@ from openhands.experiments.experiment_manager import ExperimentManagerImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk import Agent, AgentContext, LocalWorkspace
|
||||
from openhands.sdk.llm import LLM
|
||||
from openhands.sdk.plugin import PluginSource
|
||||
from openhands.sdk.secret import LookupSecret, SecretValue, StaticSecret
|
||||
from openhands.sdk.utils.paging import page_iterator
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
@@ -254,6 +256,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
request.conversation_id,
|
||||
remote_workspace=remote_workspace,
|
||||
selected_repository=request.selected_repository,
|
||||
plugins=request.plugins,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -954,6 +957,79 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
return agent.model_copy(update=updates)
|
||||
return agent
|
||||
|
||||
def _construct_initial_message_with_plugin_params(
|
||||
self,
|
||||
initial_message: SendMessageRequest | None,
|
||||
plugins: list[PluginSpec] | None,
|
||||
) -> SendMessageRequest | None:
|
||||
"""Incorporate plugin parameters into the initial message if specified.
|
||||
|
||||
Plugin parameters are formatted and appended to the initial message so the
|
||||
agent has context about the user-provided configuration values.
|
||||
|
||||
Args:
|
||||
initial_message: The original initial message, if any
|
||||
plugins: List of plugin specifications with optional parameters
|
||||
|
||||
Returns:
|
||||
The initial message with plugin parameters incorporated, or the
|
||||
original message if no plugin parameters are specified
|
||||
"""
|
||||
from openhands.agent_server.models import TextContent
|
||||
|
||||
if not plugins:
|
||||
return initial_message
|
||||
|
||||
# Collect formatted parameters from plugins that have them
|
||||
plugins_with_params = [p for p in plugins if p.parameters]
|
||||
if not plugins_with_params:
|
||||
return initial_message
|
||||
|
||||
# Format parameters, grouped by plugin if multiple
|
||||
if len(plugins_with_params) == 1:
|
||||
params_text = plugins_with_params[0].format_params_as_text()
|
||||
plugin_params_message = (
|
||||
f'\n\nPlugin Configuration Parameters:\n{params_text}'
|
||||
)
|
||||
else:
|
||||
# Group by plugin name for clarity
|
||||
formatted_plugins = []
|
||||
for plugin in plugins_with_params:
|
||||
params_text = plugin.format_params_as_text(indent=' ')
|
||||
if params_text:
|
||||
formatted_plugins.append(f'{plugin.display_name}:\n{params_text}')
|
||||
|
||||
plugin_params_message = (
|
||||
'\n\nPlugin Configuration Parameters:\n' + '\n'.join(formatted_plugins)
|
||||
)
|
||||
|
||||
if initial_message is None:
|
||||
# Create a new message with just the plugin parameters
|
||||
return SendMessageRequest(
|
||||
content=[TextContent(text=plugin_params_message.strip())],
|
||||
run=True,
|
||||
)
|
||||
|
||||
# Append plugin parameters to existing message content
|
||||
new_content = list(initial_message.content)
|
||||
if new_content and isinstance(new_content[-1], TextContent):
|
||||
# Append to the last text content
|
||||
last_content = new_content[-1]
|
||||
new_content[-1] = TextContent(
|
||||
text=last_content.text + plugin_params_message,
|
||||
cache_prompt=last_content.cache_prompt,
|
||||
enable_truncation=last_content.enable_truncation,
|
||||
)
|
||||
else:
|
||||
# Add as new text content
|
||||
new_content.append(TextContent(text=plugin_params_message.strip()))
|
||||
|
||||
return SendMessageRequest(
|
||||
role=initial_message.role,
|
||||
content=new_content,
|
||||
run=initial_message.run,
|
||||
)
|
||||
|
||||
async def _finalize_conversation_request(
|
||||
self,
|
||||
agent: Agent,
|
||||
@@ -966,6 +1042,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace: AsyncRemoteWorkspace | None,
|
||||
selected_repository: str | None,
|
||||
working_dir: str,
|
||||
plugins: list[PluginSpec] | None = None,
|
||||
) -> StartConversationRequest:
|
||||
"""Finalize the conversation request with experiment variants and skills.
|
||||
|
||||
@@ -980,6 +1057,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace: Optional remote workspace for skills loading
|
||||
selected_repository: Optional repository name
|
||||
working_dir: Working directory path
|
||||
plugins: Optional list of plugin specifications to load
|
||||
|
||||
Returns:
|
||||
Complete StartConversationRequest ready for use
|
||||
@@ -1006,6 +1084,23 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
_logger.warning(f'Failed to load skills: {e}', exc_info=True)
|
||||
# Continue without skills - don't fail conversation startup
|
||||
|
||||
# Incorporate plugin parameters into initial message if specified
|
||||
final_initial_message = self._construct_initial_message_with_plugin_params(
|
||||
initial_message, plugins
|
||||
)
|
||||
|
||||
# Convert PluginSpec list to SDK PluginSource list for agent server
|
||||
sdk_plugins: list[PluginSource] | None = None
|
||||
if plugins:
|
||||
sdk_plugins = [
|
||||
PluginSource(
|
||||
source=p.source,
|
||||
ref=p.ref,
|
||||
repo_path=p.repo_path,
|
||||
)
|
||||
for p in plugins
|
||||
]
|
||||
|
||||
# Create and return the final request
|
||||
return StartConversationRequest(
|
||||
conversation_id=conversation_id,
|
||||
@@ -1014,8 +1109,9 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
confirmation_policy=self._select_confirmation_policy(
|
||||
bool(user.confirmation_mode), user.security_analyzer
|
||||
),
|
||||
initial_message=initial_message,
|
||||
initial_message=final_initial_message,
|
||||
secrets=secrets,
|
||||
plugins=sdk_plugins,
|
||||
)
|
||||
|
||||
async def _build_start_conversation_request_for_user(
|
||||
@@ -1030,6 +1126,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
conversation_id: UUID | None = None,
|
||||
remote_workspace: AsyncRemoteWorkspace | None = None,
|
||||
selected_repository: str | None = None,
|
||||
plugins: list[PluginSpec] | None = None,
|
||||
) -> StartConversationRequest:
|
||||
"""Build a complete conversation request for a user.
|
||||
|
||||
@@ -1038,6 +1135,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
2. Configuring LLM and MCP settings
|
||||
3. Creating an agent with appropriate context
|
||||
4. Finalizing the request with skills and experiment variants
|
||||
5. Passing plugins to the agent server for remote plugin loading
|
||||
"""
|
||||
user = await self.user_context.get_user_info()
|
||||
workspace = LocalWorkspace(working_dir=working_dir)
|
||||
@@ -1070,6 +1168,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace,
|
||||
selected_repository,
|
||||
working_dir,
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
async def update_agent_server_conversation_title(
|
||||
@@ -1124,7 +1223,8 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||
) -> AppConversation | None:
|
||||
"""Update an app conversation and return it. Return None if the conversation
|
||||
did not exist."""
|
||||
did not exist.
|
||||
"""
|
||||
info = await self.app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
@@ -541,7 +541,8 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
|
||||
def _fix_timezone(self, value: datetime) -> datetime:
|
||||
"""Sqlite does not stpre timezones - and since we can't update the existing models
|
||||
we assume UTC if the timezone is missing."""
|
||||
we assume UTC if the timezone is missing.
|
||||
"""
|
||||
if not value.tzinfo:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
return value
|
||||
|
||||
@@ -68,7 +68,8 @@ class StoredAppConversationStartTask(Base): # type: ignore
|
||||
class SQLAppConversationStartTaskService(AppConversationStartTaskService):
|
||||
"""SQL implementation of AppConversationStartTaskService focused on db operations.
|
||||
|
||||
This allows storing and retrieving conversation start tasks from the database."""
|
||||
This allows storing and retrieving conversation start tasks from the database.
|
||||
"""
|
||||
|
||||
session: AsyncSession
|
||||
user_id: str | None = None
|
||||
|
||||
@@ -13,7 +13,7 @@ from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
# The version of the agent server to use for deployments.
|
||||
# Typically this will be the same as the values from the pyproject.toml
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:31536c8-python'
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:c775ff6-python'
|
||||
|
||||
|
||||
class SandboxSpecService(ABC):
|
||||
|
||||
20
poetry.lock
generated
20
poetry.lock
generated
@@ -7731,14 +7731,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.9.1"
|
||||
version = "1.10.0"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.9.1-py3-none-any.whl", hash = "sha256:ea1457760505b9ebfe6aabea08dedd010ce93aeb93edb450f00e25a0d056a723"},
|
||||
{file = "openhands_agent_server-1.9.1.tar.gz", hash = "sha256:d92a29a9d5aa94207519a5f8daad7c0a3d6641d5cba9f763f25aa4e85713fa0f"},
|
||||
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
|
||||
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -7755,14 +7755,14 @@ wsproto = ">=1.2.0"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.9.1"
|
||||
version = "1.10.0"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.9.1-py3-none-any.whl", hash = "sha256:0e732dfe0d91289536ea0410db9554d5a5b0326f60e547ea7a9d8ddab5fe93e4"},
|
||||
{file = "openhands_sdk-1.9.1.tar.gz", hash = "sha256:c6ba33f85efa4c2ec63eb1040cbe82839662bcbcf323654ed071a9ad38ce7994"},
|
||||
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
|
||||
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -7783,14 +7783,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.9.1"
|
||||
version = "1.10.0"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.9.1-py3-none-any.whl", hash = "sha256:411819657e00ffac5d5b1ba9adc6eb65a0a17cbefb5e3e1a34bb132ff61c59f2"},
|
||||
{file = "openhands_tools-1.9.1.tar.gz", hash = "sha256:331608994cce22b662038a2fed0bf7d2c1bb8dc27b1fc0a12a646e9bd76e0843"},
|
||||
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
|
||||
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -17367,4 +17367,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "fecab94e6c18e6da0c67c3a249f20cd938b47a2faff492994311c36ac4e0019a"
|
||||
content-hash = "f67478db2385eb258369313ac831b26582d744294c0996a35e786c3d7ced5db1"
|
||||
|
||||
@@ -54,9 +54,9 @@ dependencies = [
|
||||
"numpy",
|
||||
"openai==2.8",
|
||||
"openhands-aci==0.3.2",
|
||||
"openhands-agent-server==1.9.1",
|
||||
"openhands-sdk==1.9.1",
|
||||
"openhands-tools==1.9.1",
|
||||
"openhands-agent-server==1.10",
|
||||
"openhands-sdk==1.10",
|
||||
"openhands-tools==1.10",
|
||||
"opentelemetry-api>=1.33.1",
|
||||
"opentelemetry-exporter-otlp-proto-grpc>=1.33.1",
|
||||
"pathspec>=0.12.1",
|
||||
@@ -280,12 +280,9 @@ e2b-code-interpreter = { version = "^2.0.0", optional = true }
|
||||
pybase62 = "^1.0.0"
|
||||
|
||||
# V1 dependencies
|
||||
#openhands-agent-server = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-agent-server", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
#openhands-sdk = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-sdk", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
#openhands-tools = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-tools", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
openhands-sdk = "1.9.1"
|
||||
openhands-agent-server = "1.9.1"
|
||||
openhands-tools = "1.9.1"
|
||||
openhands-sdk = "1.10"
|
||||
openhands-agent-server = "1.10"
|
||||
openhands-tools = "1.10"
|
||||
python-jose = { version = ">=3.3", extras = [ "cryptography" ] }
|
||||
sqlalchemy = { extras = [ "asyncio" ], version = "^2.0.40" }
|
||||
pg8000 = "^1.31.5"
|
||||
|
||||
@@ -1786,3 +1786,855 @@ class TestLiveStatusAppConversationService:
|
||||
stdio_server = mcp_servers['stdio-server']
|
||||
assert stdio_server['command'] == 'npx'
|
||||
assert stdio_server['env'] == {'TOKEN': 'value'}
|
||||
|
||||
|
||||
class TestPluginHandling:
|
||||
"""Test cases for plugin-related functionality in LiveStatusAppConversationService."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
# Create mock dependencies
|
||||
self.mock_user_context = Mock(spec=UserContext)
|
||||
self.mock_user_auth = Mock()
|
||||
self.mock_user_context.user_auth = self.mock_user_auth
|
||||
self.mock_jwt_service = Mock()
|
||||
self.mock_sandbox_service = Mock()
|
||||
self.mock_sandbox_spec_service = Mock()
|
||||
self.mock_app_conversation_info_service = Mock()
|
||||
self.mock_app_conversation_start_task_service = Mock()
|
||||
self.mock_event_callback_service = Mock()
|
||||
self.mock_event_service = Mock()
|
||||
self.mock_httpx_client = Mock()
|
||||
|
||||
# Create service instance
|
||||
self.service = LiveStatusAppConversationService(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=self.mock_user_context,
|
||||
app_conversation_info_service=self.mock_app_conversation_info_service,
|
||||
app_conversation_start_task_service=self.mock_app_conversation_start_task_service,
|
||||
event_callback_service=self.mock_event_callback_service,
|
||||
event_service=self.mock_event_service,
|
||||
sandbox_service=self.mock_sandbox_service,
|
||||
sandbox_spec_service=self.mock_sandbox_spec_service,
|
||||
jwt_service=self.mock_jwt_service,
|
||||
sandbox_startup_timeout=30,
|
||||
sandbox_startup_poll_frequency=1,
|
||||
httpx_client=self.mock_httpx_client,
|
||||
web_url='https://test.example.com',
|
||||
openhands_provider_base_url='https://provider.example.com',
|
||||
access_token_hard_timeout=None,
|
||||
app_mode='test',
|
||||
)
|
||||
|
||||
# Mock user info
|
||||
self.mock_user = Mock()
|
||||
self.mock_user.id = 'test_user_123'
|
||||
self.mock_user.llm_model = 'gpt-4'
|
||||
self.mock_user.llm_base_url = 'https://api.openai.com/v1'
|
||||
self.mock_user.llm_api_key = 'test_api_key'
|
||||
self.mock_user.confirmation_mode = False
|
||||
self.mock_user.search_api_key = None
|
||||
self.mock_user.condenser_max_size = None
|
||||
self.mock_user.mcp_config = None
|
||||
self.mock_user.security_analyzer = None
|
||||
|
||||
# Mock sandbox
|
||||
self.mock_sandbox = Mock(spec=SandboxInfo)
|
||||
self.mock_sandbox.id = uuid4()
|
||||
self.mock_sandbox.status = SandboxStatus.RUNNING
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_no_plugins(self):
|
||||
"""Test _construct_initial_message_with_plugin_params with no plugins returns original message."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
|
||||
# Test with None initial message and None plugins
|
||||
result = self.service._construct_initial_message_with_plugin_params(None, None)
|
||||
assert result is None
|
||||
|
||||
# Test with None initial message and empty plugins list
|
||||
result = self.service._construct_initial_message_with_plugin_params(None, [])
|
||||
assert result is None
|
||||
|
||||
# Test with initial message but None plugins
|
||||
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, None
|
||||
)
|
||||
assert result is initial_msg
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_no_params(self):
|
||||
"""Test _construct_initial_message_with_plugin_params with plugins but no parameters."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Plugin with no parameters
|
||||
plugins = [PluginSpec(source='github:owner/repo')]
|
||||
|
||||
# Test with None initial message
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
None, plugins
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Test with initial message
|
||||
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
assert result is initial_msg
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_creates_new_message(self):
|
||||
"""Test _construct_initial_message_with_plugin_params creates message when no initial message."""
|
||||
from openhands.agent_server.models import TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/repo',
|
||||
parameters={'api_key': 'test123', 'debug': True},
|
||||
)
|
||||
]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
None, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Plugin Configuration Parameters:' in result.content[0].text
|
||||
assert '- api_key: test123' in result.content[0].text
|
||||
assert '- debug: True' in result.content[0].text
|
||||
assert result.run is True
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_appends_to_message(self):
|
||||
"""Test _construct_initial_message_with_plugin_params appends to existing message."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
initial_msg = SendMessageRequest(
|
||||
content=[TextContent(text='Please analyze this codebase')],
|
||||
run=False,
|
||||
)
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/repo',
|
||||
ref='v1.0.0',
|
||||
parameters={'target_dir': '/src', 'verbose': True},
|
||||
)
|
||||
]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
text = result.content[0].text
|
||||
assert text.startswith('Please analyze this codebase')
|
||||
assert 'Plugin Configuration Parameters:' in text
|
||||
assert '- target_dir: /src' in text
|
||||
assert '- verbose: True' in text
|
||||
assert result.run is False
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_preserves_role(self):
|
||||
"""Test _construct_initial_message_with_plugin_params preserves message role."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
initial_msg = SendMessageRequest(
|
||||
role='system',
|
||||
content=[TextContent(text='System message')],
|
||||
)
|
||||
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.role == 'system'
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_empty_content(self):
|
||||
"""Test _construct_initial_message_with_plugin_params handles empty content list."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
initial_msg = SendMessageRequest(content=[])
|
||||
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Plugin Configuration Parameters:' in result.content[0].text
|
||||
|
||||
def test_construct_initial_message_with_multiple_plugins(self):
|
||||
"""Test _construct_initial_message_with_plugin_params handles multiple plugins."""
|
||||
from openhands.agent_server.models import TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/plugin1',
|
||||
parameters={'key1': 'value1'},
|
||||
),
|
||||
PluginSpec(
|
||||
source='github:owner/plugin2',
|
||||
parameters={'key2': 'value2'},
|
||||
),
|
||||
]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
None, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
text = result.content[0].text
|
||||
assert 'Plugin Configuration Parameters:' in text
|
||||
# Multiple plugins should show grouped by plugin name
|
||||
assert 'plugin1' in text
|
||||
assert 'plugin2' in text
|
||||
assert 'key1: value1' in text
|
||||
assert 'key2: value2' in text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_with_plugins(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request passes plugins list to StartConversationRequest."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {'test': StaticSecret(value='secret')}
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/my-plugin',
|
||||
ref='v1.0.0',
|
||||
parameters={'api_key': 'test123'},
|
||||
)
|
||||
]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 1
|
||||
assert result.plugins[0].source == 'github:owner/my-plugin'
|
||||
assert result.plugins[0].ref == 'v1.0.0'
|
||||
# Also verify initial message contains plugin params
|
||||
assert result.initial_message is not None
|
||||
assert (
|
||||
'Plugin Configuration Parameters:' in result.initial_message.content[0].text
|
||||
)
|
||||
assert '- api_key: test123' in result.initial_message.content[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_without_plugins(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request without plugins sets plugins to None."""
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_plugin_without_ref(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request with plugin that has no ref."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Plugin without ref or parameters
|
||||
plugins = [PluginSpec(source='github:owner/my-plugin')]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 1
|
||||
assert result.plugins[0].source == 'github:owner/my-plugin'
|
||||
assert result.plugins[0].ref is None
|
||||
# No parameters, so initial message should be None
|
||||
assert result.initial_message is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_plugin_with_repo_path(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request passes repo_path to PluginSource."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Plugin with repo_path (for marketplace repos containing multiple plugins)
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/marketplace-repo',
|
||||
ref='main',
|
||||
repo_path='plugins/city-weather',
|
||||
)
|
||||
]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 1
|
||||
assert result.plugins[0].source == 'github:owner/marketplace-repo'
|
||||
assert result.plugins[0].ref == 'main'
|
||||
assert result.plugins[0].repo_path == 'plugins/city-weather'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_multiple_plugins(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request with multiple plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Multiple plugins
|
||||
plugins = [
|
||||
PluginSpec(source='github:owner/security-plugin', ref='v2.0.0'),
|
||||
PluginSpec(
|
||||
source='github:owner/monorepo',
|
||||
repo_path='plugins/logging',
|
||||
),
|
||||
PluginSpec(source='/local/path/to/plugin'),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 3
|
||||
assert result.plugins[0].source == 'github:owner/security-plugin'
|
||||
assert result.plugins[0].ref == 'v2.0.0'
|
||||
assert result.plugins[1].source == 'github:owner/monorepo'
|
||||
assert result.plugins[1].repo_path == 'plugins/logging'
|
||||
assert result.plugins[2].source == '/local/path/to/plugin'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_start_conversation_request_for_user_with_plugins(self):
|
||||
"""Test _build_start_conversation_request_for_user passes plugins to finalize method."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
self.mock_user_context.get_user_info.return_value = self.mock_user
|
||||
self.mock_user_context.get_secrets.return_value = {}
|
||||
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
|
||||
self.mock_user_context.get_mcp_api_key.return_value = None
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='https://github.com/org/plugin.git',
|
||||
ref='main',
|
||||
parameters={'config_file': 'custom.yaml'},
|
||||
)
|
||||
]
|
||||
|
||||
# Mock _finalize_conversation_request to capture the call
|
||||
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
|
||||
self.service._finalize_conversation_request = mock_finalize
|
||||
|
||||
# Act
|
||||
await self.service._build_start_conversation_request_for_user(
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'/workspace',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_finalize.assert_called_once()
|
||||
call_kwargs = mock_finalize.call_args.kwargs
|
||||
assert call_kwargs['plugins'] == plugins
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_start_conversation_request_for_user_without_plugins(self):
|
||||
"""Test _build_start_conversation_request_for_user works without plugins."""
|
||||
# Arrange
|
||||
self.mock_user_context.get_user_info.return_value = self.mock_user
|
||||
self.mock_user_context.get_secrets.return_value = {}
|
||||
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
|
||||
self.mock_user_context.get_mcp_api_key.return_value = None
|
||||
|
||||
# Mock _finalize_conversation_request
|
||||
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
|
||||
self.service._finalize_conversation_request = mock_finalize
|
||||
|
||||
# Act
|
||||
await self.service._build_start_conversation_request_for_user(
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'/workspace',
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_finalize.assert_called_once()
|
||||
call_kwargs = mock_finalize.call_args.kwargs
|
||||
assert call_kwargs.get('plugins') is None
|
||||
|
||||
|
||||
class TestPluginSpecModel:
|
||||
"""Test cases for the PluginSpec model."""
|
||||
|
||||
def test_plugin_spec_with_all_fields(self):
|
||||
"""Test PluginSpec with all fields provided."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
ref='v1.0.0',
|
||||
repo_path='plugins/my-plugin',
|
||||
parameters={'key1': 'value1', 'key2': 123, 'key3': True},
|
||||
)
|
||||
|
||||
assert plugin.source == 'github:owner/repo'
|
||||
assert plugin.ref == 'v1.0.0'
|
||||
assert plugin.repo_path == 'plugins/my-plugin'
|
||||
assert plugin.parameters == {'key1': 'value1', 'key2': 123, 'key3': True}
|
||||
|
||||
def test_plugin_spec_with_only_source(self):
|
||||
"""Test PluginSpec with only source provided."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='https://github.com/owner/repo.git')
|
||||
|
||||
assert plugin.source == 'https://github.com/owner/repo.git'
|
||||
assert plugin.ref is None
|
||||
assert plugin.repo_path is None
|
||||
assert plugin.parameters is None
|
||||
|
||||
def test_plugin_spec_serialization(self):
|
||||
"""Test PluginSpec serialization to JSON."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
ref='main',
|
||||
repo_path='plugins/my-plugin',
|
||||
parameters={'debug': True},
|
||||
)
|
||||
|
||||
json_data = plugin.model_dump()
|
||||
assert json_data == {
|
||||
'source': 'github:owner/repo',
|
||||
'ref': 'main',
|
||||
'repo_path': 'plugins/my-plugin',
|
||||
'parameters': {'debug': True},
|
||||
}
|
||||
|
||||
def test_plugin_spec_deserialization(self):
|
||||
"""Test PluginSpec deserialization from dict."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
data = {
|
||||
'source': 'github:owner/repo',
|
||||
'ref': 'v2.0.0',
|
||||
'repo_path': 'plugins/weather',
|
||||
'parameters': {'timeout': 30},
|
||||
}
|
||||
|
||||
plugin = PluginSpec.model_validate(data)
|
||||
|
||||
assert plugin.source == 'github:owner/repo'
|
||||
assert plugin.ref == 'v2.0.0'
|
||||
assert plugin.repo_path == 'plugins/weather'
|
||||
assert plugin.parameters == {'timeout': 30}
|
||||
|
||||
def test_plugin_spec_display_name_github_format(self):
|
||||
"""Test display_name extracts repo name from github:owner/repo format."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='github:owner/my-plugin')
|
||||
assert plugin.display_name == 'my-plugin'
|
||||
|
||||
def test_plugin_spec_display_name_git_url(self):
|
||||
"""Test display_name extracts repo name from git URL."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='https://github.com/owner/repo.git')
|
||||
assert plugin.display_name == 'repo.git'
|
||||
|
||||
def test_plugin_spec_display_name_local_path(self):
|
||||
"""Test display_name extracts directory name from local path."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='/local/path/to/plugin')
|
||||
assert plugin.display_name == 'plugin'
|
||||
|
||||
def test_plugin_spec_display_name_no_slash(self):
|
||||
"""Test display_name returns source as-is when no slash present."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='local-plugin')
|
||||
assert plugin.display_name == 'local-plugin'
|
||||
|
||||
def test_plugin_spec_format_params_as_text(self):
|
||||
"""Test format_params_as_text formats parameters as text."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
parameters={'key1': 'value1', 'key2': 123},
|
||||
)
|
||||
|
||||
result = plugin.format_params_as_text()
|
||||
assert result == '- key1: value1\n- key2: 123'
|
||||
|
||||
def test_plugin_spec_format_params_as_text_with_indent(self):
|
||||
"""Test format_params_as_text with custom indent."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
parameters={'debug': True},
|
||||
)
|
||||
|
||||
result = plugin.format_params_as_text(indent=' ')
|
||||
assert result == ' - debug: True'
|
||||
|
||||
def test_plugin_spec_format_params_as_text_no_params(self):
|
||||
"""Test format_params_as_text returns None when no parameters."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='github:owner/repo')
|
||||
assert plugin.format_params_as_text() is None
|
||||
|
||||
def test_plugin_spec_inherits_repo_path_validation(self):
|
||||
"""Test PluginSpec inherits validation from SDK's PluginSource."""
|
||||
import pytest
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Should reject absolute paths
|
||||
with pytest.raises(ValueError, match='must be relative'):
|
||||
PluginSpec(source='github:owner/repo', repo_path='/absolute/path')
|
||||
|
||||
# Should reject parent traversal
|
||||
with pytest.raises(ValueError, match="cannot contain '..'"):
|
||||
PluginSpec(source='github:owner/repo', repo_path='../parent/path')
|
||||
|
||||
|
||||
class TestAppConversationStartRequestWithPlugins:
|
||||
"""Test cases for AppConversationStartRequest with plugins field."""
|
||||
|
||||
def test_start_request_with_plugins(self):
|
||||
"""Test AppConversationStartRequest with plugins field."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/my-plugin',
|
||||
ref='v1.0.0',
|
||||
parameters={'api_key': 'test'},
|
||||
)
|
||||
]
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
title='Test conversation',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
assert request.plugins is not None
|
||||
assert len(request.plugins) == 1
|
||||
assert request.plugins[0].source == 'github:owner/my-plugin'
|
||||
assert request.plugins[0].ref == 'v1.0.0'
|
||||
assert request.plugins[0].parameters == {'api_key': 'test'}
|
||||
|
||||
def test_start_request_without_plugins(self):
|
||||
"""Test AppConversationStartRequest without plugins field (backwards compatible)."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
)
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
title='Test conversation',
|
||||
)
|
||||
|
||||
assert request.plugins is None
|
||||
|
||||
def test_start_request_serialization_with_plugins(self):
|
||||
"""Test AppConversationStartRequest serialization includes plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [PluginSpec(source='github:owner/repo')]
|
||||
request = AppConversationStartRequest(plugins=plugins)
|
||||
|
||||
json_data = request.model_dump()
|
||||
|
||||
assert 'plugins' in json_data
|
||||
assert len(json_data['plugins']) == 1
|
||||
assert json_data['plugins'][0]['source'] == 'github:owner/repo'
|
||||
|
||||
def test_start_request_deserialization_with_plugins(self):
|
||||
"""Test AppConversationStartRequest deserialization from JSON with plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
)
|
||||
|
||||
data = {
|
||||
'title': 'Test',
|
||||
'plugins': [
|
||||
{
|
||||
'source': 'github:owner/plugin',
|
||||
'ref': 'main',
|
||||
'parameters': {'key': 'value'},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
request = AppConversationStartRequest.model_validate(data)
|
||||
|
||||
assert request.plugins is not None
|
||||
assert len(request.plugins) == 1
|
||||
assert request.plugins[0].source == 'github:owner/plugin'
|
||||
assert request.plugins[0].ref == 'main'
|
||||
assert request.plugins[0].parameters == {'key': 'value'}
|
||||
|
||||
def test_start_request_with_multiple_plugins(self):
|
||||
"""Test AppConversationStartRequest with multiple plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(source='github:owner/plugin1', ref='v1.0.0'),
|
||||
PluginSpec(source='github:owner/plugin2', repo_path='plugins/sub'),
|
||||
PluginSpec(source='/local/path'),
|
||||
]
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
title='Test conversation',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
assert request.plugins is not None
|
||||
assert len(request.plugins) == 3
|
||||
assert request.plugins[0].source == 'github:owner/plugin1'
|
||||
assert request.plugins[1].repo_path == 'plugins/sub'
|
||||
assert request.plugins[2].source == '/local/path'
|
||||
|
||||
Reference in New Issue
Block a user