mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1b94276a1e | |||
| bd7b2470bf |
@@ -1,205 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Downgrade script for migrated users.
|
||||
|
||||
This script identifies users who have been migrated (already_migrated=True)
|
||||
and reverts them back to the pre-migration state.
|
||||
|
||||
Usage:
|
||||
# Dry run - just list the users that would be downgraded
|
||||
python downgrade_migrated_users.py --dry-run
|
||||
|
||||
# Downgrade a specific user by their keycloak_user_id
|
||||
python downgrade_migrated_users.py --user-id <user_id>
|
||||
|
||||
# Downgrade all migrated users (with confirmation)
|
||||
python downgrade_migrated_users.py --all
|
||||
|
||||
# Downgrade all migrated users without confirmation (dangerous!)
|
||||
python downgrade_migrated_users.py --all --no-confirm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
# Add the enterprise directory to the path
|
||||
sys.path.insert(0, '/workspace/project/OpenHands/enterprise')
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from storage.database import session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
|
||||
def get_migrated_users() -> list[str]:
|
||||
"""Get list of keycloak_user_ids for users who have been migrated.
|
||||
|
||||
This includes:
|
||||
1. Users with already_migrated=True in user_settings (migrated users)
|
||||
2. Users in the 'user' table who don't have a user_settings entry (new sign-ups)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Get users from user_settings with already_migrated=True
|
||||
migrated_result = session.execute(
|
||||
select(UserSettings.keycloak_user_id).where(
|
||||
UserSettings.already_migrated.is_(True)
|
||||
)
|
||||
)
|
||||
migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]}
|
||||
|
||||
# Get users from the 'user' table (new sign-ups won't have user_settings)
|
||||
# These are users who signed up after the migration was deployed
|
||||
new_signup_result = session.execute(
|
||||
text("""
|
||||
SELECT CAST(u.id AS VARCHAR)
|
||||
FROM "user" u
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM user_settings us
|
||||
WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR)
|
||||
)
|
||||
""")
|
||||
)
|
||||
new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]}
|
||||
|
||||
# Combine both sets
|
||||
all_users = migrated_users | new_signups
|
||||
return list(all_users)
|
||||
|
||||
|
||||
async def downgrade_user(user_id: str) -> bool:
|
||||
"""Downgrade a single user.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak_user_id to downgrade
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await UserStore.downgrade_user(user_id)
|
||||
if result:
|
||||
print(f'✓ Successfully downgraded user: {user_id}')
|
||||
return True
|
||||
else:
|
||||
print(f'✗ Failed to downgrade user: {user_id}')
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f'✗ Error downgrading user {user_id}: {e}')
|
||||
logger.exception(
|
||||
'downgrade_script:error',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Downgrade migrated users back to pre-migration state'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Just list users that would be downgraded, without making changes',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--user-id',
|
||||
type=str,
|
||||
help='Downgrade a specific user by keycloak_user_id',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--all',
|
||||
action='store_true',
|
||||
help='Downgrade all migrated users',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-confirm',
|
||||
action='store_true',
|
||||
help='Skip confirmation prompt (use with caution!)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get list of migrated users
|
||||
migrated_users = get_migrated_users()
|
||||
print(f'\nFound {len(migrated_users)} migrated user(s).')
|
||||
|
||||
if args.dry_run:
|
||||
print('\n--- DRY RUN MODE ---')
|
||||
print('The following users would be downgraded:')
|
||||
for user_id in migrated_users:
|
||||
print(f' - {user_id}')
|
||||
print('\nNo changes were made.')
|
||||
return
|
||||
|
||||
if args.user_id:
|
||||
# Downgrade a specific user
|
||||
if args.user_id not in migrated_users:
|
||||
print(f'\nUser {args.user_id} is not in the migrated users list.')
|
||||
print('Either the user was not migrated, or the user_id is incorrect.')
|
||||
return
|
||||
|
||||
print(f'\nDowngrading user: {args.user_id}')
|
||||
if not args.no_confirm:
|
||||
confirm = input('Are you sure? (yes/no): ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
success = await downgrade_user(args.user_id)
|
||||
if success:
|
||||
print('\nDowngrade completed successfully.')
|
||||
else:
|
||||
print('\nDowngrade failed. Check logs for details.')
|
||||
sys.exit(1)
|
||||
|
||||
elif args.all:
|
||||
# Downgrade all migrated users
|
||||
if not migrated_users:
|
||||
print('\nNo migrated users to downgrade.')
|
||||
return
|
||||
|
||||
print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).')
|
||||
if not args.no_confirm:
|
||||
print('\nThis will:')
|
||||
print(' - Revert LiteLLM team/user budget settings')
|
||||
print(' - Delete organization entries')
|
||||
print(' - Delete user entries in the new schema')
|
||||
print(' - Reset the already_migrated flag')
|
||||
print('\nUsers to downgrade:')
|
||||
for user_id in migrated_users[:10]: # Show first 10
|
||||
print(f' - {user_id}')
|
||||
if len(migrated_users) > 10:
|
||||
print(f' ... and {len(migrated_users) - 10} more')
|
||||
|
||||
confirm = input('\nType "yes" to proceed: ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
print('\nStarting downgrade...\n')
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for user_id in migrated_users:
|
||||
success = await downgrade_user(user_id)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
print('\n--- Summary ---')
|
||||
print(f'Successful: {success_count}')
|
||||
print(f'Failed: {fail_count}')
|
||||
|
||||
if fail_count > 0:
|
||||
sys.exit(1)
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
print('\nPlease specify --dry-run, --user-id, or --all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -14,6 +14,7 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -107,10 +108,13 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
f'[GitHub] Sent summary instruction to conversation {conversation_id} {summary_event}'
|
||||
)
|
||||
|
||||
# Update the processor state - the outer session will commit this
|
||||
# Update the processor state
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -126,15 +130,14 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
|
||||
logger.info(f'[GitHub] Summary sent for conversation {conversation_id}')
|
||||
|
||||
# Mark callback as completed status - the outer session will commit this
|
||||
# Mark callback as completed status
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f'[GitHub] Error processing conversation callback: {str(e)}'
|
||||
)
|
||||
# Mark callback as error to prevent infinite re-invocation
|
||||
# The outer session will commit this
|
||||
callback.status = CallbackStatus.ERROR
|
||||
callback.updated_at = datetime.now()
|
||||
|
||||
@@ -98,29 +98,6 @@ def decrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
return get_fernet().decrypt(b64decode(value.encode())).decode()
|
||||
|
||||
|
||||
def encrypt_legacy_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_legacy_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = encrypt_legacy_value(value)
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return b64encode(
|
||||
get_fernet().encrypt(value.get_secret_value().encode())
|
||||
).decode()
|
||||
else:
|
||||
return b64encode(get_fernet().encrypt(value.encode())).decode()
|
||||
|
||||
|
||||
def get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
|
||||
@@ -195,163 +195,6 @@ class LiteLlmManager:
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
async def downgrade_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
"""Downgrade a migrated user's LiteLLM entries back to the pre-migration state.
|
||||
|
||||
This reverses the migrate_entries operation:
|
||||
1. Get the user max budget from their org team in litellm
|
||||
2. Set the max budget in the user in litellm (restore from team)
|
||||
3. Add the user back to the default team in litellm
|
||||
4. Update keys to remove org team association
|
||||
5. Remove the user from their org team in litellm
|
||||
6. Delete the user org team in litellm
|
||||
|
||||
Note: The database changes (already_migrated flag, org/org_member deletion)
|
||||
should be handled separately by the caller.
|
||||
|
||||
Args:
|
||||
org_id: The organization ID (which is also the team_id in litellm)
|
||||
keycloak_user_id: The user's Keycloak ID
|
||||
user_settings: The user's settings object
|
||||
|
||||
Returns:
|
||||
The user_settings if downgrade was successful, None otherwise
|
||||
"""
|
||||
logger.info(
|
||||
'LiteLlmManager:downgrade_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
if not local_deploy:
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
# Step 1: Get the team info to retrieve the budget
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:get_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
team_info = await LiteLlmManager._get_team(client, org_id)
|
||||
if not team_info:
|
||||
logger.error(
|
||||
'LiteLlmManager:downgrade_entries:team_not_found',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get team budget (max_budget) and spend to calculate current credits
|
||||
team_data = team_info.get('team_info', {})
|
||||
max_budget = team_data.get('max_budget', 0.0)
|
||||
spend = team_data.get('spend', 0.0)
|
||||
|
||||
# Get user membership info for budget in team
|
||||
user_membership = await LiteLlmManager._get_user_team_info(
|
||||
client, keycloak_user_id, org_id
|
||||
)
|
||||
if user_membership:
|
||||
# Use user's budget in team if available
|
||||
user_max_budget_in_team = user_membership.get('max_budget_in_team')
|
||||
user_spend_in_team = user_membership.get('spend', 0.0)
|
||||
if user_max_budget_in_team is not None:
|
||||
max_budget = user_max_budget_in_team
|
||||
spend = user_spend_in_team
|
||||
|
||||
# Calculate total budget to restore (credits + spend = max_budget)
|
||||
# We restore the full max_budget that was on the team/user-in-team
|
||||
restored_budget = max_budget if max_budget else 0.0
|
||||
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:budget_info',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
'restored_budget': restored_budget,
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: Update user to set their max_budget back from unlimited
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_user',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_user(
|
||||
client, keycloak_user_id, max_budget=restored_budget, spend=spend
|
||||
)
|
||||
|
||||
# Step 3: Add user back to the default team
|
||||
if LITE_LLM_TEAM_ID:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:add_to_default_team',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'default_team_id': LITE_LLM_TEAM_ID,
|
||||
},
|
||||
)
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, LITE_LLM_TEAM_ID, restored_budget
|
||||
)
|
||||
|
||||
# Step 4: Update keys to remove org team association (set team_id to default)
|
||||
if user_settings.llm_api_key:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key,
|
||||
team_id=LITE_LLM_TEAM_ID,
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_byor_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key_for_byor,
|
||||
team_id=LITE_LLM_TEAM_ID,
|
||||
)
|
||||
|
||||
# Step 5: Remove user from their org team
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:remove_from_org_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._remove_user_from_team(
|
||||
client, keycloak_user_id, org_id
|
||||
)
|
||||
|
||||
# Step 6: Delete the org team
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:delete_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._delete_team(client, org_id)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:downgrade_entries:complete',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
async def update_team_and_users_budget(
|
||||
team_id: str,
|
||||
@@ -794,45 +637,6 @@ class LiteLlmManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _remove_user_from_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_delete',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
if not response.is_success:
|
||||
if response.status_code == 404:
|
||||
# User not in team, that's fine for downgrade
|
||||
logger.info(
|
||||
'User not in team during removal',
|
||||
extra={'user_id': keycloak_user_id, 'team_id': team_id},
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_removing_litellm_user_from_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
'LiteLlmManager:_remove_user_from_team:user_removed',
|
||||
extra={'user_id': keycloak_user_id, 'team_id': team_id},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _generate_key(
|
||||
client: httpx.AsyncClient,
|
||||
@@ -1076,7 +880,6 @@ class LiteLlmManager:
|
||||
delete_user = staticmethod(with_http_client(_delete_user))
|
||||
delete_team = staticmethod(with_http_client(_delete_team))
|
||||
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
|
||||
remove_user_from_team = staticmethod(with_http_client(_remove_user_from_team))
|
||||
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
|
||||
@@ -17,10 +17,7 @@ from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.encrypt_utils import (
|
||||
decrypt_legacy_model,
|
||||
encrypt_legacy_value,
|
||||
)
|
||||
from storage.encrypt_utils import decrypt_legacy_model
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
@@ -241,6 +238,7 @@ class UserStore:
|
||||
if not custom_settings:
|
||||
del org_member_kwargs['llm_model']
|
||||
del org_member_kwargs['llm_base_url']
|
||||
del org_member_kwargs['llm_api_key_for_byor']
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
@@ -332,256 +330,6 @@ class UserStore:
|
||||
)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def downgrade_user(user_id: str) -> UserSettings | None:
|
||||
"""Downgrade a migrated user back to the pre-migration state.
|
||||
|
||||
This reverses the migrate_user operation:
|
||||
1. Get the user's settings from user_settings table (migrated users) or
|
||||
create new user_settings from org_members table (new sign-ups)
|
||||
2. Call LiteLlmManager.downgrade_entries to revert LiteLLM state
|
||||
3. Copy user_id from conversation_metadata_saas to conversation_metadata
|
||||
4. Delete conversation_metadata_saas entries
|
||||
5. Reset org_id columns in related tables (stripe_customers, slack_users, etc.)
|
||||
6. Delete the org_member and org entries
|
||||
7. Delete the user entry
|
||||
8. Set already_migrated=False on user_settings
|
||||
|
||||
For new sign-ups (users who registered after migration was deployed),
|
||||
there won't be an existing user_settings entry. In this case, we fall back
|
||||
to the org_members table to get the user's API keys and settings, and create
|
||||
a new user_settings entry for them.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to downgrade
|
||||
|
||||
Returns:
|
||||
The user_settings if downgrade was successful, None otherwise.
|
||||
Returns None if the org has multiple members (not a personal org).
|
||||
"""
|
||||
logger.info(
|
||||
'user_store:downgrade_user:start',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
# Get the user and their org_member
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:user_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the user's personal org (org_id == user_id)
|
||||
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:org_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the user_settings (for migrated users)
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# For new sign-ups after migration, user_settings won't exist
|
||||
# Fall back to getting data from org_members
|
||||
is_new_signup = False
|
||||
if not user_settings:
|
||||
logger.info(
|
||||
'user_store:downgrade_user:user_settings_not_found_checking_org_members',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Get org_members for this org - should only be one for personal orgs
|
||||
org_members = (
|
||||
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
|
||||
)
|
||||
|
||||
if len(org_members) != 1:
|
||||
logger.error(
|
||||
'user_store:downgrade_user:unexpected_org_members_count',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'org_members_count': len(org_members),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
org_member = org_members[0]
|
||||
is_new_signup = True
|
||||
|
||||
# Create a new user_settings entry from OrgMember, User, and Org data
|
||||
# This is needed for new sign-ups who don't have user_settings
|
||||
user_settings = UserStore._create_user_settings_from_entities(
|
||||
user_id, org_member, user, org
|
||||
)
|
||||
session.add(user_settings)
|
||||
session.flush()
|
||||
logger.info(
|
||||
'user_store:downgrade_user:created_user_settings_from_org_member',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Call LiteLLM downgrade
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
logger.debug(
|
||||
'user_store:downgrade_user:calling_litellm_downgrade_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Get the API keys for LiteLLM downgrade
|
||||
if is_new_signup:
|
||||
# For new signups, we already have decrypted values in user_settings
|
||||
decrypted_user_settings = user_settings
|
||||
else:
|
||||
# For migrated users, decrypt the legacy model
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
|
||||
await LiteLlmManager.downgrade_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
logger.debug(
|
||||
'user_store:downgrade_user:done_litellm_downgrade_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Step 3: Copy user_id from conversation_metadata_saas to conversation_metadata
|
||||
# This ensures any conversations created after migration have their user_id
|
||||
# preserved in the original table before we delete the saas entries
|
||||
session.execute(
|
||||
text("""
|
||||
UPDATE conversation_metadata
|
||||
SET user_id = :user_id
|
||||
WHERE conversation_id IN (
|
||||
SELECT conversation_id
|
||||
FROM conversation_metadata_saas
|
||||
WHERE user_id = :user_uuid
|
||||
)
|
||||
"""),
|
||||
{'user_id': user_id, 'user_uuid': user_uuid},
|
||||
)
|
||||
|
||||
# Step 4: Delete conversation_metadata_saas entries
|
||||
session.execute(
|
||||
text('DELETE FROM conversation_metadata_saas WHERE user_id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 5: Reset org_id columns in related tables
|
||||
# Reset stripe_customers
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset slack_users
|
||||
session.execute(
|
||||
text('UPDATE slack_users SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset slack_conversation
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset api_keys
|
||||
session.execute(
|
||||
text('UPDATE api_keys SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset custom_secrets
|
||||
session.execute(
|
||||
text('UPDATE custom_secrets SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset billing_sessions
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 6: Delete org_member entries for this org
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 7: Delete the user entry
|
||||
session.execute(
|
||||
text('DELETE FROM "user" WHERE id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Delete the org entry
|
||||
session.execute(
|
||||
text('DELETE FROM org WHERE id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 8: Set already_migrated=False on user_settings and encrypt fields
|
||||
user_settings.already_migrated = False
|
||||
|
||||
# Re-encrypt the sensitive fields before storing in the DB
|
||||
encrypt_keys = [
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
]
|
||||
for key in encrypt_keys:
|
||||
value = getattr(user_settings, key, None)
|
||||
if value is not None:
|
||||
setattr(user_settings, key, encrypt_legacy_value(value))
|
||||
|
||||
session.merge(user_settings)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_store:downgrade_user:complete',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (sync version).
|
||||
@@ -772,96 +520,6 @@ class UserStore:
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def _create_user_settings_from_entities(
|
||||
user_id: str, org_member: OrgMember, user: User, org: Org
|
||||
) -> UserSettings:
|
||||
"""Create UserSettings from OrgMember, User, and Org data.
|
||||
|
||||
Uses OrgMember values first. If an OrgMember field is None and there's
|
||||
a corresponding "default_" field in Org, use the Org value.
|
||||
Also pulls relevant fields from User.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID
|
||||
org_member: The OrgMember entity
|
||||
user: The User entity
|
||||
org: The Org entity
|
||||
|
||||
Returns:
|
||||
A new UserSettings object populated from the entities
|
||||
"""
|
||||
# Mapping from OrgMember fields to corresponding Org "default_" fields
|
||||
org_member_to_org_default = {
|
||||
'llm_model': 'default_llm_model',
|
||||
'llm_base_url': 'default_llm_base_url',
|
||||
'max_iterations': 'default_max_iterations',
|
||||
}
|
||||
|
||||
def get_value_with_org_fallback(field_name: str, org_member_value):
|
||||
"""Get value from OrgMember, falling back to Org default if None."""
|
||||
if org_member_value is not None:
|
||||
return org_member_value
|
||||
org_default_field = org_member_to_org_default.get(field_name)
|
||||
if org_default_field and hasattr(org, org_default_field):
|
||||
return getattr(org, org_default_field)
|
||||
return None
|
||||
|
||||
# Get values from OrgMember with Org fallback for fields with default_ prefix
|
||||
llm_model = get_value_with_org_fallback('llm_model', org_member.llm_model)
|
||||
llm_base_url = get_value_with_org_fallback(
|
||||
'llm_base_url', org_member.llm_base_url
|
||||
)
|
||||
max_iterations = get_value_with_org_fallback(
|
||||
'max_iterations', org_member.max_iterations
|
||||
)
|
||||
|
||||
return UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
# OrgMember fields
|
||||
llm_api_key=org_member.llm_api_key.get_secret_value()
|
||||
if org_member.llm_api_key
|
||||
else None,
|
||||
llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value()
|
||||
if org_member.llm_api_key_for_byor
|
||||
else None,
|
||||
llm_model=llm_model,
|
||||
llm_base_url=llm_base_url,
|
||||
max_iterations=max_iterations,
|
||||
# User fields
|
||||
accepted_tos=user.accepted_tos,
|
||||
enable_sound_notifications=user.enable_sound_notifications,
|
||||
language=user.language,
|
||||
user_consents_to_analytics=user.user_consents_to_analytics,
|
||||
email=user.email,
|
||||
email_verified=user.email_verified,
|
||||
git_user_name=user.git_user_name,
|
||||
git_user_email=user.git_user_email,
|
||||
# Org fields
|
||||
agent=org.agent,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
user_version=org.org_version,
|
||||
mcp_config=org.mcp_config,
|
||||
search_api_key=org.search_api_key.get_secret_value()
|
||||
if org.search_api_key
|
||||
else None,
|
||||
sandbox_api_key=org.sandbox_api_key.get_secret_value()
|
||||
if org.sandbox_api_key
|
||||
else None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
condenser_max_size=org.condenser_max_size,
|
||||
already_migrated=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _has_custom_settings(
|
||||
user_settings: UserSettings, old_user_version: int | None
|
||||
|
||||
@@ -1126,174 +1126,3 @@ class TestLiteLlmManager:
|
||||
'http://test.url/team/delete',
|
||||
json={'team_ids': [team_id]},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_from_team_successful(self):
|
||||
"""
|
||||
GIVEN: Valid user_id and team_id
|
||||
WHEN: _remove_user_from_team is called
|
||||
THEN: HTTP POST is made to remove user from team
|
||||
"""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
|
||||
with (
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
await LiteLlmManager._remove_user_from_team(
|
||||
mock_client, 'test-user-id', 'test-team-id'
|
||||
)
|
||||
|
||||
mock_client.post.assert_called_once_with(
|
||||
'http://test.url/team/member_delete',
|
||||
json={
|
||||
'team_id': 'test-team-id',
|
||||
'user_id': 'test-user-id',
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_from_team_not_found(self):
|
||||
"""
|
||||
GIVEN: User not in team
|
||||
WHEN: _remove_user_from_team is called
|
||||
THEN: 404 response is handled gracefully without raising
|
||||
"""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = False
|
||||
mock_response.status_code = 404
|
||||
mock_response.text = 'User not found in team'
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with (
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Should not raise an exception
|
||||
await LiteLlmManager._remove_user_from_team(
|
||||
mock_client, 'test-user-id', 'test-team-id'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_missing_config(self, mock_user_settings):
|
||||
"""Test downgrade_entries when LiteLLM config is missing."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_team_not_found(self, mock_user_settings):
|
||||
"""Test downgrade_entries when team is not found."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_team', new_callable=AsyncMock
|
||||
) as mock_get_team:
|
||||
mock_get_team.return_value = None
|
||||
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_successful(self, mock_user_settings):
|
||||
"""Test successful downgrade_entries operation."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_team_info_response = MagicMock()
|
||||
mock_team_info_response.is_success = True
|
||||
mock_team_info_response.status_code = 200
|
||||
mock_team_info_response.json.return_value = {
|
||||
'team_info': {
|
||||
'max_budget': 100.0,
|
||||
'spend': 20.0,
|
||||
},
|
||||
'team_memberships': [
|
||||
{
|
||||
'user_id': 'test-user-id',
|
||||
'team_id': 'test-org-id',
|
||||
'max_budget_in_team': 100.0,
|
||||
'spend': 20.0,
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_team_info_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'default-team'
|
||||
):
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_team_info_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# downgrade_entries returns the user_settings
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
|
||||
# Verify downgrade steps were called:
|
||||
# 1. get_team (GET)
|
||||
# 2. get_user_team_info (GET via _get_team)
|
||||
# 3. update_user (POST)
|
||||
# 4. add_user_to_team (POST)
|
||||
# 5. update_key (POST)
|
||||
# 6. remove_user_from_team (POST)
|
||||
# 7. delete_team (POST)
|
||||
assert mock_client.get.call_count >= 1
|
||||
assert mock_client.post.call_count >= 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_local_deployment(self, mock_user_settings):
|
||||
"""Test downgrade_entries in local deployment mode (skips LiteLLM calls)."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': 'true'}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# In local deployment, should return user_settings without
|
||||
# making any LiteLLM calls
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
|
||||
@@ -8,3 +8,4 @@ node_modules/
|
||||
/blob-report/
|
||||
/playwright/.cache/
|
||||
.react-router/
|
||||
ralph/
|
||||
|
||||
@@ -183,4 +183,170 @@ describe("GitBranchDropdown", () => {
|
||||
expect(mockOnBranchSelect).toHaveBeenCalledWith(MOCK_BRANCHES[1]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loading states", () => {
|
||||
it("should show spinner when isLoading is true", () => {
|
||||
mockUseBranchData.mockReturnValue({
|
||||
branches: [],
|
||||
isLoading: true,
|
||||
isError: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
isFetchingNextPage: false,
|
||||
isSearchLoading: false,
|
||||
});
|
||||
|
||||
render(
|
||||
<GitBranchDropdown
|
||||
repository="user/repo"
|
||||
provider="github"
|
||||
selectedBranch={null}
|
||||
onBranchSelect={mockOnBranchSelect}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
// Spinner should be visible
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show spinner when isSearchLoading is true", () => {
|
||||
mockUseBranchData.mockReturnValue({
|
||||
branches: MOCK_BRANCHES,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
isFetchingNextPage: false,
|
||||
isSearchLoading: true,
|
||||
});
|
||||
|
||||
render(
|
||||
<GitBranchDropdown
|
||||
repository="user/repo"
|
||||
provider="github"
|
||||
selectedBranch={null}
|
||||
onBranchSelect={mockOnBranchSelect}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
// Spinner should be visible during search loading
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show spinner when isFetchingNextPage is true", () => {
|
||||
mockUseBranchData.mockReturnValue({
|
||||
branches: MOCK_BRANCHES,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: true,
|
||||
isFetchingNextPage: true,
|
||||
isSearchLoading: false,
|
||||
});
|
||||
|
||||
render(
|
||||
<GitBranchDropdown
|
||||
repository="user/repo"
|
||||
provider="github"
|
||||
selectedBranch={null}
|
||||
onBranchSelect={mockOnBranchSelect}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
// Spinner should be visible while fetching next page
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show branch icon when not in any loading state", () => {
|
||||
mockUseBranchData.mockReturnValue({
|
||||
branches: MOCK_BRANCHES,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
isFetchingNextPage: false,
|
||||
isSearchLoading: false,
|
||||
});
|
||||
|
||||
render(
|
||||
<GitBranchDropdown
|
||||
repository="user/repo"
|
||||
provider="github"
|
||||
selectedBranch={null}
|
||||
onBranchSelect={mockOnBranchSelect}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
// Spinner should NOT be visible when not loading
|
||||
expect(screen.queryByTestId("spinner")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -57,6 +57,7 @@ const setupDefaultMocks = (
|
||||
repositoryDataOverrides: Partial<
|
||||
ReturnType<typeof mockUseRepositoryData>
|
||||
> = {},
|
||||
urlSearchOverrides: Partial<ReturnType<typeof mockUseUrlSearch>> = {},
|
||||
) => {
|
||||
mockUseRepositoryData.mockReturnValue({
|
||||
repositories: MOCK_REPOSITORIES,
|
||||
@@ -73,6 +74,7 @@ const setupDefaultMocks = (
|
||||
mockUseUrlSearch.mockReturnValue({
|
||||
urlSearchResults: [],
|
||||
isUrlSearchLoading: false,
|
||||
...urlSearchOverrides,
|
||||
});
|
||||
};
|
||||
|
||||
@@ -81,9 +83,10 @@ const renderDropdown = (
|
||||
repositoryDataOverrides: Partial<
|
||||
ReturnType<typeof mockUseRepositoryData>
|
||||
> = {},
|
||||
urlSearchOverrides: Partial<ReturnType<typeof mockUseUrlSearch>> = {},
|
||||
) => {
|
||||
// Set up mocks with optional overrides
|
||||
setupDefaultMocks(repositoryDataOverrides);
|
||||
setupDefaultMocks(repositoryDataOverrides, urlSearchOverrides);
|
||||
|
||||
return render(
|
||||
<GitRepoDropdown
|
||||
@@ -231,4 +234,41 @@ describe("GitRepoDropdown", () => {
|
||||
expect(mockOnChange).toHaveBeenCalledWith(MOCK_REPOSITORIES[1]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loading states", () => {
|
||||
it("should show spinner when isLoading is true", () => {
|
||||
renderDropdown({}, { isLoading: true, repositories: [] });
|
||||
|
||||
// Spinner should be visible
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show spinner when isSearchLoading is true", () => {
|
||||
renderDropdown({}, { isSearchLoading: true });
|
||||
|
||||
// Spinner should be visible during search loading
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show spinner when isFetchingNextPage is true", () => {
|
||||
renderDropdown({}, { isFetchingNextPage: true, hasNextPage: true });
|
||||
|
||||
// Spinner should be visible while fetching next page
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show spinner when isUrlSearchLoading is true", () => {
|
||||
renderDropdown({}, {}, { isUrlSearchLoading: true });
|
||||
|
||||
// Spinner should be visible during URL search loading
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show repo icon when not in any loading state", () => {
|
||||
renderDropdown();
|
||||
|
||||
// Spinner should NOT be visible when not loading
|
||||
expect(screen.queryByTestId("spinner")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { BranchLoadingState } from "#/components/features/home/repository-selection/branch-loading-state";
|
||||
|
||||
describe("BranchLoadingState", () => {
|
||||
it("should render spinner with correct testId", () => {
|
||||
render(<BranchLoadingState />);
|
||||
|
||||
expect(screen.getByTestId("branch-dropdown-loading")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render spinner element inside the component", () => {
|
||||
render(<BranchLoadingState />);
|
||||
|
||||
// The Spinner component renders with testId="spinner" by default
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display translated loading text", () => {
|
||||
render(<BranchLoadingState />);
|
||||
|
||||
// The mock translates keys to themselves
|
||||
expect(screen.getByText("HOME$LOADING_BRANCHES")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should apply wrapper className when provided", () => {
|
||||
render(<BranchLoadingState wrapperClassName="custom-class" />);
|
||||
|
||||
const wrapper = screen.getByTestId("branch-dropdown-loading");
|
||||
expect(wrapper).toHaveClass("custom-class");
|
||||
});
|
||||
|
||||
it("should have default styling classes", () => {
|
||||
render(<BranchLoadingState />);
|
||||
|
||||
const wrapper = screen.getByTestId("branch-dropdown-loading");
|
||||
expect(wrapper).toHaveClass("flex", "items-center", "gap-2");
|
||||
});
|
||||
});
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { RepositoryLoadingState } from "#/components/features/home/repository-selection/repository-loading-state";
|
||||
|
||||
describe("RepositoryLoadingState", () => {
|
||||
it("should render wrapper with correct testId", () => {
|
||||
render(<RepositoryLoadingState />);
|
||||
|
||||
expect(screen.getByTestId("repo-dropdown-loading")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render spinner element inside the component", () => {
|
||||
render(<RepositoryLoadingState />);
|
||||
|
||||
// The Spinner component renders with testId="spinner" by default
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display translated loading text", () => {
|
||||
render(<RepositoryLoadingState />);
|
||||
|
||||
// The mock translates keys to themselves
|
||||
expect(screen.getByText("HOME$LOADING_REPOSITORIES")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should apply wrapper className when provided", () => {
|
||||
render(<RepositoryLoadingState wrapperClassName="custom-class" />);
|
||||
|
||||
const wrapper = screen.getByTestId("repo-dropdown-loading");
|
||||
expect(wrapper).toHaveClass("custom-class");
|
||||
});
|
||||
|
||||
it("should have default styling classes", () => {
|
||||
render(<RepositoryLoadingState />);
|
||||
|
||||
const wrapper = screen.getByTestId("repo-dropdown-loading");
|
||||
expect(wrapper).toHaveClass("flex", "items-center", "gap-2");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,77 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
describe("Spinner", () => {
|
||||
it("should render with default testId", () => {
|
||||
render(<Spinner />);
|
||||
|
||||
expect(screen.getByTestId("spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should have animate-spin class for rotation animation", () => {
|
||||
render(<Spinner />);
|
||||
|
||||
expect(screen.getByTestId("spinner")).toHaveClass("animate-spin");
|
||||
});
|
||||
|
||||
it("should render with default size (md = w-6 h-6)", () => {
|
||||
render(<Spinner />);
|
||||
|
||||
const spinner = screen.getByTestId("spinner");
|
||||
expect(spinner).toHaveClass("w-6", "h-6");
|
||||
});
|
||||
|
||||
it("should render with sm size (w-4 h-4)", () => {
|
||||
render(<Spinner size="sm" />);
|
||||
|
||||
const spinner = screen.getByTestId("spinner");
|
||||
expect(spinner).toHaveClass("w-4", "h-4");
|
||||
});
|
||||
|
||||
it("should render with lg size (w-10 h-10)", () => {
|
||||
render(<Spinner size="lg" />);
|
||||
|
||||
const spinner = screen.getByTestId("spinner");
|
||||
expect(spinner).toHaveClass("w-10", "h-10");
|
||||
});
|
||||
|
||||
it("should render with xl size (w-16 h-16)", () => {
|
||||
render(<Spinner size="xl" />);
|
||||
|
||||
const spinner = screen.getByTestId("spinner");
|
||||
expect(spinner).toHaveClass("w-16", "h-16");
|
||||
});
|
||||
|
||||
it("should use custom testId when provided", () => {
|
||||
render(<Spinner testId="custom-spinner" />);
|
||||
|
||||
expect(screen.getByTestId("custom-spinner")).toBeInTheDocument();
|
||||
expect(screen.queryByTestId("spinner")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should apply custom className", () => {
|
||||
render(<Spinner className="text-white" />);
|
||||
|
||||
expect(screen.getByTestId("spinner")).toHaveClass("text-white");
|
||||
});
|
||||
|
||||
it("should render with label text when provided", () => {
|
||||
render(<Spinner label="Loading..." />);
|
||||
|
||||
expect(screen.getByText("Loading...")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not render label when not provided", () => {
|
||||
render(<Spinner />);
|
||||
|
||||
expect(screen.queryByText(/loading/i)).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should use border-based styling for circular spinner appearance", () => {
|
||||
render(<Spinner />);
|
||||
|
||||
const spinner = screen.getByTestId("spinner");
|
||||
expect(spinner).toHaveClass("border-2", "rounded-full");
|
||||
});
|
||||
});
|
||||
@@ -20,7 +20,7 @@ import { useSendMessage } from "#/hooks/use-send-message";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
|
||||
import { ScrollToBottomButton } from "#/components/shared/buttons/scroll-to-bottom-button";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { ChatMessagesSkeleton } from "./chat-messages-skeleton";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
import { useErrorMessageStore } from "#/stores/error-message-store";
|
||||
@@ -294,7 +294,7 @@ export function ChatInterface() {
|
||||
|
||||
{isChatLoading && !isReturningToConversation && (
|
||||
<div className="flex justify-center" data-testid="loading-spinner">
|
||||
<LoadingSpinner size="small" />
|
||||
<Spinner size="sm" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Spinner } from "@heroui/react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { ModalBody } from "#/components/shared/modals/modal-body";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { Typography } from "#/ui/typography";
|
||||
|
||||
export function LoadingMicroagentBody() {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Spinner } from "@heroui/react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { MicroagentStatus } from "#/types/microagent-status";
|
||||
import { SuccessIndicator } from "../success-indicator";
|
||||
import { Typography } from "#/ui/typography";
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import toast from "react-hot-toast";
|
||||
import { Spinner } from "@heroui/react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { TOAST_OPTIONS } from "#/utils/custom-toast-handlers";
|
||||
import CloseIcon from "#/icons/close.svg?react";
|
||||
import { SuccessIndicator } from "../success-indicator";
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useMemo } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import CircleIcon from "#/icons/u-circle.svg?react";
|
||||
import CheckCircleIcon from "#/icons/u-check-circle.svg?react";
|
||||
import LoadingIcon from "#/icons/loading.svg?react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { Typography } from "#/ui/typography";
|
||||
@@ -24,7 +24,7 @@ export function TaskItem({ task }: TaskItemProps) {
|
||||
case "todo":
|
||||
return <CircleIcon className="w-4 h-4 text-[#ffffff]" />;
|
||||
case "in_progress":
|
||||
return <LoadingIcon className="w-4 h-4 text-[#ffffff] animate-spin" />;
|
||||
return <Spinner size="sm" className="text-white" />;
|
||||
case "done":
|
||||
return <CheckCircleIcon className="w-4 h-4 text-[#A3A3A3]" />;
|
||||
default:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { LoaderCircle } from "lucide-react";
|
||||
import FileIcon from "#/icons/file.svg?react";
|
||||
import { RemoveFileButton } from "./remove-file-button";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { cn, getFileExtension } from "#/utils/utils";
|
||||
|
||||
interface UploadedFileProps {
|
||||
@@ -39,7 +39,7 @@ export function UploadedFile({
|
||||
</div>
|
||||
{isLoading && (
|
||||
<div className="flex items-center justify-center">
|
||||
<LoaderCircle className="animate-spin w-5 h-5" color="white" />
|
||||
<Spinner size="sm" className="w-5 h-5 text-white" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React from "react";
|
||||
import { LoaderCircle } from "lucide-react";
|
||||
import { RemoveFileButton } from "./remove-file-button";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
interface UploadedImageProps {
|
||||
image: File;
|
||||
@@ -30,7 +30,7 @@ export function UploadedImage({
|
||||
<div className="group min-w-[51px] min-h-[49px] w-[51px] h-[49px] rounded-lg bg-[#525252] relative flex items-center justify-center">
|
||||
<RemoveFileButton onClick={onRemove} />
|
||||
{isLoading ? (
|
||||
<LoaderCircle className="animate-spin w-5 h-5" color="white" />
|
||||
<Spinner size="sm" className="w-5 h-5 text-white" />
|
||||
) : (
|
||||
imageUrl && (
|
||||
<img
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import { LoaderCircle } from "lucide-react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
export function AgentLoading() {
|
||||
return (
|
||||
<div data-testid="agent-loading-spinner">
|
||||
<LoaderCircle className="animate-spin w-4 h-4" color="white" />
|
||||
</div>
|
||||
<Spinner size="sm" testId="agent-loading-spinner" className="text-white" />
|
||||
);
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import { useDeleteConversation } from "#/hooks/mutation/use-delete-conversation"
|
||||
import { useUnifiedPauseConversationSandbox } from "#/hooks/mutation/use-unified-stop-conversation";
|
||||
import { ConfirmDeleteModal } from "./confirm-delete-modal";
|
||||
import { ConfirmStopModal } from "./confirm-stop-modal";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { ExitConversationModal } from "./exit-conversation-modal";
|
||||
import { useClickOutsideElement } from "#/hooks/use-click-outside-element";
|
||||
import { Provider } from "#/types/settings";
|
||||
@@ -212,7 +212,7 @@ export function ConversationPanel({ onClose }: ConversationPanelProps) {
|
||||
{/* Loading indicator for fetching more conversations */}
|
||||
{isFetchingNextPage && (
|
||||
<div className="flex justify-center py-4">
|
||||
<LoadingSpinner size="small" />
|
||||
<Spinner size="sm" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
export function SkillsLoadingState() {
|
||||
return (
|
||||
<div className="flex justify-center items-center py-8">
|
||||
<div className="animate-spin rounded-full h-8 w-8 border-t-2 border-b-2 border-primary" />
|
||||
<Spinner size="md" className="text-primary" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { LoaderCircle } from "lucide-react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
export function ConversationLoading() {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<div className="bg-[#25272D] border border-[#525252] rounded-xl flex flex-col items-center justify-center h-full w-full">
|
||||
<LoaderCircle className="animate-spin w-16 h-16" color="white" />
|
||||
<Spinner size="xl" className="text-white" />
|
||||
<span className="text-2xl font-normal leading-5 text-white p-4">
|
||||
{t(I18nKey.HOME$LOADING)}
|
||||
</span>
|
||||
|
||||
@@ -8,26 +8,7 @@ import { getLanguageFromPath } from "#/utils/get-language-from-path";
|
||||
import { cn } from "#/utils/utils";
|
||||
import ChevronUp from "#/icons/chveron-up.svg?react";
|
||||
import { useUnifiedGitDiff } from "#/hooks/query/use-unified-git-diff";
|
||||
|
||||
interface LoadingSpinnerProps {
|
||||
className?: string;
|
||||
}
|
||||
|
||||
// TODO: Move out of this file and replace the current spinner with this one
|
||||
function LoadingSpinner({ className }: LoadingSpinnerProps) {
|
||||
return (
|
||||
<div className="flex items-center justify-center">
|
||||
<div
|
||||
className={cn(
|
||||
"animate-spin rounded-full border-4 border-gray-200 border-t-blue-500",
|
||||
className,
|
||||
)}
|
||||
role="status"
|
||||
aria-label="Loading"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
const STATUS_MAP: Record<GitChangeStatus, string | IconType> = {
|
||||
A: LuFilePlus,
|
||||
@@ -144,7 +125,7 @@ export function FileDiffViewer({ path, type }: FileDiffViewerProps) {
|
||||
onClick={() => setIsCollapsed((prev) => !prev)}
|
||||
>
|
||||
<span className="text-sm w-full text-content flex items-center gap-2">
|
||||
{isFetchingData && <LoadingSpinner className="w-5 h-5" />}
|
||||
{isFetchingData && <Spinner className="w-5 h-5 text-blue-500" />}
|
||||
{!isFetchingData && statusIcon}
|
||||
<strong className="w-full truncate">{filePath}</strong>
|
||||
<button data-testid="collapse" type="button">
|
||||
|
||||
@@ -17,6 +17,7 @@ import { ToggleButton } from "../shared/toggle-button";
|
||||
import { ErrorMessage } from "../shared/error-message";
|
||||
import { BranchDropdownMenu } from "./branch-dropdown-menu";
|
||||
import BranchIcon from "#/icons/u-code-branch.svg?react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
export interface GitBranchDropdownProps {
|
||||
repository: string | null;
|
||||
@@ -186,7 +187,7 @@ export function GitBranchDropdown({
|
||||
<div className="relative">
|
||||
<div className="absolute left-2 top-1/2 transform -translate-y-1/2 z-10">
|
||||
{isLoadingState ? (
|
||||
<div className="animate-spin h-4 w-4 border-2 border-blue-500 border-t-transparent rounded-full" />
|
||||
<Spinner size="sm" className="text-blue-500" />
|
||||
) : (
|
||||
<BranchIcon width={16} height={16} />
|
||||
)}
|
||||
|
||||
@@ -25,6 +25,7 @@ import { I18nKey } from "#/i18n/declaration";
|
||||
import RepoIcon from "#/icons/repo.svg?react";
|
||||
import { useHomeStore } from "#/stores/home-store";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
export interface GitRepoDropdownProps {
|
||||
provider: Provider;
|
||||
@@ -318,7 +319,7 @@ export function GitRepoDropdown({
|
||||
<div className="relative">
|
||||
<div className="absolute left-2 top-1/2 transform -translate-y-1/2 z-10">
|
||||
{isLoadingState ? (
|
||||
<div className="animate-spin h-4 w-4 border-2 border-blue-500 border-t-transparent rounded-full" />
|
||||
<Spinner size="sm" className="text-blue-500" />
|
||||
) : (
|
||||
<RepoIcon width={16} height={16} />
|
||||
)}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Spinner } from "@heroui/react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { cn } from "#/utils/utils";
|
||||
|
||||
interface BranchLoadingStateProps {
|
||||
|
||||
+1
-1
@@ -1,5 +1,5 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Spinner } from "@heroui/react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { cn } from "#/utils/utils";
|
||||
|
||||
export interface RepositoryLoadingStateProps {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React from "react";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
interface LoadingSpinnerProps {
|
||||
hasSelection: boolean;
|
||||
@@ -17,10 +17,7 @@ export function LoadingSpinner({
|
||||
hasSelection ? "right-11" : "right-6",
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className="animate-spin h-4 w-4 border-2 border-blue-500 border-t-transparent rounded-full"
|
||||
data-testid={testId}
|
||||
/>
|
||||
<Spinner size="sm" testId={testId} className="text-blue-500" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useEffect } from "react";
|
||||
import { Spinner } from "@heroui/react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { MicroagentManagementMicroagentCard } from "./microagent-management-microagent-card";
|
||||
import { MicroagentManagementLearnThisRepo } from "./microagent-management-learn-this-repo";
|
||||
import { useRepositoryMicroagents } from "#/hooks/query/use-repository-microagents";
|
||||
@@ -82,7 +82,7 @@ export function MicroagentManagementRepoMicroagents({
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="pb-4 flex justify-center">
|
||||
<Spinner size="sm" data-testid="loading-spinner" />
|
||||
<Spinner size="sm" testId="loading-spinner" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
+2
-1
@@ -1,5 +1,6 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Accordion, AccordionItem, Spinner } from "@heroui/react";
|
||||
import { Accordion, AccordionItem } from "@heroui/react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { MicroagentManagementRepoMicroagents } from "./microagent-management-repo-microagents";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import { TabType } from "#/types/microagent-management";
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
import { useEffect, useState, useMemo } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Spinner } from "@heroui/react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { MicroagentManagementSidebarHeader } from "./microagent-management-sidebar-header";
|
||||
import { MicroagentManagementSidebarTabs } from "./microagent-management-sidebar-tabs";
|
||||
import { useGitRepositories } from "#/hooks/query/use-git-repositories";
|
||||
|
||||
+2
-2
@@ -1,5 +1,5 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Spinner } from "@heroui/react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { useMicroagentManagementStore } from "#/stores/microagent-management-store";
|
||||
import { useRepositoryMicroagentContent } from "#/hooks/query/use-repository-microagent-content";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
@@ -34,7 +34,7 @@ export function MicroagentManagementViewMicroagentContent() {
|
||||
<div className="w-full h-full p-6 bg-[#ffffff1a] rounded-2xl text-white text-sm">
|
||||
{isLoading && (
|
||||
<div className="flex items-center justify-center w-full h-full">
|
||||
<Spinner size="lg" data-testid="loading-microagent-content-spinner" />
|
||||
<Spinner size="lg" testId="loading-microagent-content-spinner" />
|
||||
</div>
|
||||
)}
|
||||
{error && (
|
||||
|
||||
@@ -6,7 +6,7 @@ import { cn } from "#/utils/utils";
|
||||
import MoneyIcon from "#/icons/money.svg?react";
|
||||
import { SettingsInput } from "../settings/settings-input";
|
||||
import { BrandButton } from "../settings/brand-button";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { amountIsValid } from "#/utils/amount-is-valid";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { PoweredByStripeTag } from "./powered-by-stripe-tag";
|
||||
@@ -54,7 +54,7 @@ export function PaymentForm() {
|
||||
{!isLoading && (
|
||||
<span data-testid="user-balance">${Number(balance).toFixed(2)}</span>
|
||||
)}
|
||||
{isLoading && <LoadingSpinner size="small" />}
|
||||
{isLoading && <Spinner size="sm" />}
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-3">
|
||||
@@ -79,7 +79,7 @@ export function PaymentForm() {
|
||||
>
|
||||
{t(I18nKey.PAYMENT$ADD_CREDIT)}
|
||||
</BrandButton>
|
||||
{isPending && <LoadingSpinner size="small" />}
|
||||
{isPending && <Spinner size="sm" />}
|
||||
<PoweredByStripeTag />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useTranslation, Trans } from "react-i18next";
|
||||
import { FaTrash, FaEye, FaEyeSlash, FaCopy } from "react-icons/fa6";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { BrandButton } from "#/components/features/settings/brand-button";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { ApiKey, CreateApiKeyResponse } from "#/api/api-keys";
|
||||
import {
|
||||
displayErrorToast,
|
||||
@@ -62,7 +62,7 @@ function LlmApiKeyManager({
|
||||
isDisabled={refreshLlmApiKey.isPending}
|
||||
>
|
||||
{refreshLlmApiKey.isPending ? (
|
||||
<LoadingSpinner size="small" />
|
||||
<Spinner size="sm" />
|
||||
) : (
|
||||
t(I18nKey.SETTINGS$REFRESH_LLM_API_KEY)
|
||||
)}
|
||||
@@ -146,7 +146,7 @@ function ApiKeysTable({ apiKeys, isLoading, onDeleteKey }: ApiKeysTableProps) {
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex justify-center p-4">
|
||||
<LoadingSpinner size="large" />
|
||||
<Spinner size="xl" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { BrandButton } from "#/components/features/settings/brand-button";
|
||||
import { SettingsInput } from "#/components/features/settings/settings-input";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { CreateApiKeyResponse } from "#/api/api-keys";
|
||||
import {
|
||||
displayErrorToast,
|
||||
@@ -59,7 +59,7 @@ export function CreateApiKeyModal({
|
||||
isDisabled={createApiKeyMutation.isPending || !newKeyName.trim()}
|
||||
>
|
||||
{createApiKeyMutation.isPending ? (
|
||||
<LoadingSpinner size="small" />
|
||||
<Spinner size="sm" />
|
||||
) : (
|
||||
t(I18nKey.BUTTON$CREATE)
|
||||
)}
|
||||
|
||||
@@ -2,7 +2,7 @@ import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { BrandButton } from "#/components/features/settings/brand-button";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { ApiKey } from "#/api/api-keys";
|
||||
import {
|
||||
displayErrorToast,
|
||||
@@ -49,7 +49,7 @@ export function DeleteApiKeyModal({
|
||||
isDisabled={deleteApiKeyMutation.isPending}
|
||||
>
|
||||
{deleteApiKeyMutation.isPending ? (
|
||||
<LoadingSpinner size="small" />
|
||||
<Spinner size="sm" />
|
||||
) : (
|
||||
t(I18nKey.BUTTON$DELETE)
|
||||
)}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import ProfileIcon from "#/icons/profile.svg?react";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { Avatar } from "./avatar";
|
||||
@@ -33,7 +33,7 @@ export function UserAvatar({ onClick, avatarUrl, isLoading }: UserAvatarProps) {
|
||||
className="text-[#9099AC]"
|
||||
/>
|
||||
)}
|
||||
{isLoading && <LoadingSpinner size="small" />}
|
||||
{isLoading && <Spinner size="sm" testId="loading-spinner" />}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import LoadingSpinnerOuter from "#/icons/loading-outer.svg?react";
|
||||
import { cn } from "#/utils/utils";
|
||||
|
||||
interface LoadingSpinnerProps {
|
||||
size: "small" | "large";
|
||||
}
|
||||
|
||||
export function LoadingSpinner({ size }: LoadingSpinnerProps) {
|
||||
const sizeStyle =
|
||||
size === "small" ? "w-[25px] h-[25px]" : "w-[50px] h-[50px]";
|
||||
|
||||
return (
|
||||
<div data-testid="loading-spinner" className={cn("relative", sizeStyle)}>
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-full border-4 border-[#525252] absolute",
|
||||
sizeStyle,
|
||||
)}
|
||||
/>
|
||||
<LoadingSpinnerOuter className={cn("absolute animate-spin", sizeStyle)} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useAIConfigOptions } from "#/hooks/query/use-ai-config-options";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { LoadingSpinner } from "../../loading-spinner";
|
||||
import { Spinner } from "../../spinner";
|
||||
import { ModalBackdrop } from "../modal-backdrop";
|
||||
import { SettingsForm } from "./settings-form";
|
||||
import { Settings } from "#/types/settings";
|
||||
@@ -42,7 +42,7 @@ export function SettingsModal({ onClose, settings }: SettingsModalProps) {
|
||||
|
||||
{aiConfigOptions.isLoading && (
|
||||
<div className="flex justify-center">
|
||||
<LoadingSpinner size="small" />
|
||||
<Spinner size="sm" />
|
||||
</div>
|
||||
)}
|
||||
{aiConfigOptions.data && (
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import { cn } from "#/utils/utils";
|
||||
|
||||
interface SpinnerProps {
|
||||
size?: "sm" | "md" | "lg" | "xl";
|
||||
label?: string;
|
||||
testId?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
const sizeClasses = {
|
||||
sm: "w-4 h-4",
|
||||
md: "w-6 h-6",
|
||||
lg: "w-10 h-10",
|
||||
xl: "w-16 h-16",
|
||||
};
|
||||
|
||||
export function Spinner({
|
||||
size = "md",
|
||||
label,
|
||||
testId = "spinner",
|
||||
className,
|
||||
}: SpinnerProps) {
|
||||
const spinnerElement = (
|
||||
<div
|
||||
data-testid={testId}
|
||||
className={cn(
|
||||
"animate-spin rounded-full border-2 border-current border-t-transparent",
|
||||
sizeClasses[size],
|
||||
className,
|
||||
)}
|
||||
/>
|
||||
);
|
||||
|
||||
if (label) {
|
||||
return (
|
||||
<div className="flex flex-col items-center gap-2">
|
||||
{spinnerElement}
|
||||
<span>{label}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return spinnerElement;
|
||||
}
|
||||
@@ -3,7 +3,7 @@ import { useTranslation } from "react-i18next";
|
||||
import { TaskItem as TaskItemType } from "#/types/v1/core/base/common";
|
||||
import CircleIcon from "#/icons/u-circle.svg?react";
|
||||
import CheckCircleIcon from "#/icons/u-check-circle.svg?react";
|
||||
import LoadingIcon from "#/icons/loading.svg?react";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
@@ -20,7 +20,7 @@ export function TaskItem({ task }: TaskItemProps) {
|
||||
case "todo":
|
||||
return <CircleIcon className="w-4 h-4 text-[#ffffff]" />;
|
||||
case "in_progress":
|
||||
return <LoadingIcon className="w-4 h-4 text-[#ffffff] animate-spin" />;
|
||||
return <Spinner size="sm" className="text-white" />;
|
||||
case "done":
|
||||
return <CheckCircleIcon className="w-4 h-4 text-[#A3A3A3]" />;
|
||||
default:
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
<svg width="66" height="66" viewBox="0 0 66 66" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M63 33C63 16.4315 49.5685 3 33 3C16.4315 3 3 16.4315 3 33C3 49.5685 16.4315 63 33 63"
|
||||
stroke="#007AFF" stroke-width="6" stroke-linecap="round" />
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 264 B |
@@ -1,3 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8 0.5C9.74234 0.5 11.4264 1.12774 12.7442 2.26758C14.062 3.40746 14.9259 4.98379 15.1768 6.70801L15.2188 6.99316H13.7061L13.6709 6.78516C13.4431 5.44635 12.7486 4.23161 11.711 3.35547C10.6731 2.47925 9.35827 1.99805 8 1.99805C6.64182 1.99811 5.32782 2.47931 4.29004 3.35547C3.25229 4.23161 2.55792 5.44628 2.33007 6.78516L2.29394 6.99316H0.782227L0.824219 6.70801C1.07515 4.98389 1.9382 3.40745 3.25586 2.26758C4.57357 1.12776 6.25771 0.500069 8 0.5Z" fill="currentColor" stroke="currentColor" stroke-width="0.5"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 630 B |
@@ -2,6 +2,7 @@
|
||||
import React, { useState } from "react";
|
||||
import { useSearchParams } from "react-router";
|
||||
import { useIsAuthed } from "#/hooks/query/use-is-authed";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
export default function DeviceVerify() {
|
||||
const [searchParams] = useSearchParams();
|
||||
@@ -131,7 +132,7 @@ export default function DeviceVerify() {
|
||||
<div className="min-h-screen flex items-center justify-center bg-background">
|
||||
<div className="max-w-md w-full mx-auto p-6 bg-card rounded-lg shadow-lg">
|
||||
<div className="text-center">
|
||||
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-primary mx-auto mb-4" />
|
||||
<Spinner size="md" className="text-primary mx-auto mb-4" />
|
||||
<p className="text-muted-foreground">
|
||||
Processing device verification...
|
||||
</p>
|
||||
@@ -251,7 +252,7 @@ export default function DeviceVerify() {
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-background">
|
||||
<div className="text-center">
|
||||
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-primary mx-auto mb-4" />
|
||||
<Spinner size="md" className="text-primary mx-auto mb-4" />
|
||||
<p className="text-muted-foreground">
|
||||
Processing device verification...
|
||||
</p>
|
||||
|
||||
@@ -6,6 +6,7 @@ import { useGitHubAuthUrl } from "#/hooks/use-github-auth-url";
|
||||
import { useEmailVerification } from "#/hooks/use-email-verification";
|
||||
import { LoginContent } from "#/components/features/auth/login-content";
|
||||
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
|
||||
export default function LoginPage() {
|
||||
const navigate = useNavigate();
|
||||
@@ -46,7 +47,7 @@ export default function LoginPage() {
|
||||
if (isAuthLoading || config.isLoading) {
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-base">
|
||||
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-white" />
|
||||
<Spinner size="md" className="text-white" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ import { LOCAL_STORAGE_KEYS } from "#/utils/local-storage";
|
||||
import { EmailVerificationGuard } from "#/components/features/guards/email-verification-guard";
|
||||
import { MaintenanceBanner } from "#/components/features/maintenance/maintenance-banner";
|
||||
import { cn, isMobileDevice } from "#/utils/utils";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import { useAppTitle } from "#/hooks/use-app-title";
|
||||
|
||||
export function ErrorBoundary() {
|
||||
@@ -200,7 +200,7 @@ export default function MainApp() {
|
||||
if (shouldRedirectToLogin) {
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-base">
|
||||
<LoadingSpinner size="large" />
|
||||
<Spinner size="xl" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import { useSharedConversation } from "#/hooks/query/use-shared-conversation";
|
||||
import { useSharedConversationEvents } from "#/hooks/query/use-shared-conversation-events";
|
||||
import { Messages as V1Messages } from "#/components/v1/chat";
|
||||
import { shouldRenderEvent } from "#/components/v1/chat/event-content-helpers/should-render-event";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { Spinner } from "#/components/shared/spinner";
|
||||
import OpenHandsLogo from "#/assets/branding/openhands-logo.svg?react";
|
||||
|
||||
export default function SharedConversation() {
|
||||
@@ -39,7 +39,7 @@ export default function SharedConversation() {
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center h-screen bg-neutral-900">
|
||||
<LoadingSpinner size="large" />
|
||||
<Spinner size="xl" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -14,7 +14,6 @@ from openhands.app_server.sandbox.sandbox_models import SandboxStatus
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.conversation.state import ConversationExecutionStatus
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.plugin import PluginSource
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@@ -25,45 +24,6 @@ class AgentType(Enum):
|
||||
PLAN = 'plan'
|
||||
|
||||
|
||||
class PluginSpec(PluginSource):
|
||||
"""Specification for loading a plugin into a conversation.
|
||||
|
||||
Extends SDK's PluginSource with user-provided plugin configuration parameters.
|
||||
Inherits source, ref, and repo_path fields along with their validation.
|
||||
"""
|
||||
|
||||
parameters: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description='User-provided values for plugin input parameters',
|
||||
)
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Extract a friendly display name from the plugin source.
|
||||
|
||||
Examples:
|
||||
- 'github:owner/repo' -> 'repo'
|
||||
- 'https://github.com/owner/repo.git' -> 'repo.git'
|
||||
- '/local/path' -> 'path'
|
||||
"""
|
||||
return self.source.split('/')[-1] if '/' in self.source else self.source
|
||||
|
||||
def format_params_as_text(self, indent: str = '') -> str | None:
|
||||
"""Format parameters as a readable text block for display.
|
||||
|
||||
Args:
|
||||
indent: Optional prefix to add before each parameter line.
|
||||
|
||||
Returns:
|
||||
Formatted parameters string, or None if no parameters.
|
||||
"""
|
||||
if not self.parameters:
|
||||
return None
|
||||
return '\n'.join(
|
||||
f'{indent}- {key}: {value}' for key, value in self.parameters.items()
|
||||
)
|
||||
|
||||
|
||||
class AppConversationInfo(BaseModel):
|
||||
"""Conversation info which does not contain status."""
|
||||
|
||||
@@ -158,15 +118,6 @@ class AppConversationStartRequest(OpenHandsModel):
|
||||
|
||||
public: bool | None = None
|
||||
|
||||
# Plugin parameters - for loading remote plugins into the conversation
|
||||
plugins: list[PluginSpec] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
'List of plugins to load for this conversation. Plugins are loaded '
|
||||
'and their skills/MCP config are merged into the agent.'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class AppConversationUpdateRequest(BaseModel):
|
||||
public: bool
|
||||
@@ -196,8 +147,7 @@ class AppConversationStartTask(OpenHandsModel):
|
||||
|
||||
Because starting an app conversation can be slow (And can involve starting a sandbox),
|
||||
we kick off a background task for it. Once the conversation is started, the app_conversation_id
|
||||
is populated.
|
||||
"""
|
||||
is populated."""
|
||||
|
||||
id: OpenHandsUUID = Field(default_factory=uuid4)
|
||||
created_by_user_id: str | None
|
||||
|
||||
@@ -621,6 +621,7 @@ async def _stream_app_conversation_start(
|
||||
user_context: UserContext,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream a json list, item by item."""
|
||||
|
||||
# Because the original dependencies are closed after the method returns, we need
|
||||
# a new dependency context which will continue intil the stream finishes.
|
||||
state = InjectorState()
|
||||
|
||||
@@ -105,8 +105,7 @@ class AppConversationService(ABC):
|
||||
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||
) -> AppConversation | None:
|
||||
"""Update an app conversation and return it. Return None if the conversation
|
||||
did not exist.
|
||||
"""
|
||||
did not exist."""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
||||
|
||||
@@ -51,8 +51,7 @@ PRE_COMMIT_LOCAL = '.git/hooks/pre-commit.local'
|
||||
class AppConversationServiceBase(AppConversationService, ABC):
|
||||
"""App Conversation service which adds git specific functionality.
|
||||
|
||||
Sets up repositories and installs hooks
|
||||
"""
|
||||
Sets up repositories and installs hooks"""
|
||||
|
||||
init_git_in_empty_workspace: bool
|
||||
user_context: UserContext
|
||||
@@ -485,6 +484,7 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
security_analyzer_str: String value from settings
|
||||
httpx_client: HTTP client for making API requests
|
||||
"""
|
||||
|
||||
if session_api_key is None:
|
||||
return
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartTask,
|
||||
AppConversationStartTaskStatus,
|
||||
AppConversationUpdateRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_service import (
|
||||
AppConversationService,
|
||||
@@ -80,7 +79,6 @@ from openhands.experiments.experiment_manager import ExperimentManagerImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk import Agent, AgentContext, LocalWorkspace
|
||||
from openhands.sdk.llm import LLM
|
||||
from openhands.sdk.plugin import PluginSource
|
||||
from openhands.sdk.secret import LookupSecret, SecretValue, StaticSecret
|
||||
from openhands.sdk.utils.paging import page_iterator
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
@@ -256,7 +254,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
request.conversation_id,
|
||||
remote_workspace=remote_workspace,
|
||||
selected_repository=request.selected_repository,
|
||||
plugins=request.plugins,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -957,79 +954,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
return agent.model_copy(update=updates)
|
||||
return agent
|
||||
|
||||
def _construct_initial_message_with_plugin_params(
|
||||
self,
|
||||
initial_message: SendMessageRequest | None,
|
||||
plugins: list[PluginSpec] | None,
|
||||
) -> SendMessageRequest | None:
|
||||
"""Incorporate plugin parameters into the initial message if specified.
|
||||
|
||||
Plugin parameters are formatted and appended to the initial message so the
|
||||
agent has context about the user-provided configuration values.
|
||||
|
||||
Args:
|
||||
initial_message: The original initial message, if any
|
||||
plugins: List of plugin specifications with optional parameters
|
||||
|
||||
Returns:
|
||||
The initial message with plugin parameters incorporated, or the
|
||||
original message if no plugin parameters are specified
|
||||
"""
|
||||
from openhands.agent_server.models import TextContent
|
||||
|
||||
if not plugins:
|
||||
return initial_message
|
||||
|
||||
# Collect formatted parameters from plugins that have them
|
||||
plugins_with_params = [p for p in plugins if p.parameters]
|
||||
if not plugins_with_params:
|
||||
return initial_message
|
||||
|
||||
# Format parameters, grouped by plugin if multiple
|
||||
if len(plugins_with_params) == 1:
|
||||
params_text = plugins_with_params[0].format_params_as_text()
|
||||
plugin_params_message = (
|
||||
f'\n\nPlugin Configuration Parameters:\n{params_text}'
|
||||
)
|
||||
else:
|
||||
# Group by plugin name for clarity
|
||||
formatted_plugins = []
|
||||
for plugin in plugins_with_params:
|
||||
params_text = plugin.format_params_as_text(indent=' ')
|
||||
if params_text:
|
||||
formatted_plugins.append(f'{plugin.display_name}:\n{params_text}')
|
||||
|
||||
plugin_params_message = (
|
||||
'\n\nPlugin Configuration Parameters:\n' + '\n'.join(formatted_plugins)
|
||||
)
|
||||
|
||||
if initial_message is None:
|
||||
# Create a new message with just the plugin parameters
|
||||
return SendMessageRequest(
|
||||
content=[TextContent(text=plugin_params_message.strip())],
|
||||
run=True,
|
||||
)
|
||||
|
||||
# Append plugin parameters to existing message content
|
||||
new_content = list(initial_message.content)
|
||||
if new_content and isinstance(new_content[-1], TextContent):
|
||||
# Append to the last text content
|
||||
last_content = new_content[-1]
|
||||
new_content[-1] = TextContent(
|
||||
text=last_content.text + plugin_params_message,
|
||||
cache_prompt=last_content.cache_prompt,
|
||||
enable_truncation=last_content.enable_truncation,
|
||||
)
|
||||
else:
|
||||
# Add as new text content
|
||||
new_content.append(TextContent(text=plugin_params_message.strip()))
|
||||
|
||||
return SendMessageRequest(
|
||||
role=initial_message.role,
|
||||
content=new_content,
|
||||
run=initial_message.run,
|
||||
)
|
||||
|
||||
async def _finalize_conversation_request(
|
||||
self,
|
||||
agent: Agent,
|
||||
@@ -1042,7 +966,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace: AsyncRemoteWorkspace | None,
|
||||
selected_repository: str | None,
|
||||
working_dir: str,
|
||||
plugins: list[PluginSpec] | None = None,
|
||||
) -> StartConversationRequest:
|
||||
"""Finalize the conversation request with experiment variants and skills.
|
||||
|
||||
@@ -1057,7 +980,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace: Optional remote workspace for skills loading
|
||||
selected_repository: Optional repository name
|
||||
working_dir: Working directory path
|
||||
plugins: Optional list of plugin specifications to load
|
||||
|
||||
Returns:
|
||||
Complete StartConversationRequest ready for use
|
||||
@@ -1084,23 +1006,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
_logger.warning(f'Failed to load skills: {e}', exc_info=True)
|
||||
# Continue without skills - don't fail conversation startup
|
||||
|
||||
# Incorporate plugin parameters into initial message if specified
|
||||
final_initial_message = self._construct_initial_message_with_plugin_params(
|
||||
initial_message, plugins
|
||||
)
|
||||
|
||||
# Convert PluginSpec list to SDK PluginSource list for agent server
|
||||
sdk_plugins: list[PluginSource] | None = None
|
||||
if plugins:
|
||||
sdk_plugins = [
|
||||
PluginSource(
|
||||
source=p.source,
|
||||
ref=p.ref,
|
||||
repo_path=p.repo_path,
|
||||
)
|
||||
for p in plugins
|
||||
]
|
||||
|
||||
# Create and return the final request
|
||||
return StartConversationRequest(
|
||||
conversation_id=conversation_id,
|
||||
@@ -1109,9 +1014,8 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
confirmation_policy=self._select_confirmation_policy(
|
||||
bool(user.confirmation_mode), user.security_analyzer
|
||||
),
|
||||
initial_message=final_initial_message,
|
||||
initial_message=initial_message,
|
||||
secrets=secrets,
|
||||
plugins=sdk_plugins,
|
||||
)
|
||||
|
||||
async def _build_start_conversation_request_for_user(
|
||||
@@ -1126,7 +1030,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
conversation_id: UUID | None = None,
|
||||
remote_workspace: AsyncRemoteWorkspace | None = None,
|
||||
selected_repository: str | None = None,
|
||||
plugins: list[PluginSpec] | None = None,
|
||||
) -> StartConversationRequest:
|
||||
"""Build a complete conversation request for a user.
|
||||
|
||||
@@ -1135,7 +1038,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
2. Configuring LLM and MCP settings
|
||||
3. Creating an agent with appropriate context
|
||||
4. Finalizing the request with skills and experiment variants
|
||||
5. Passing plugins to the agent server for remote plugin loading
|
||||
"""
|
||||
user = await self.user_context.get_user_info()
|
||||
workspace = LocalWorkspace(working_dir=working_dir)
|
||||
@@ -1168,7 +1070,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace,
|
||||
selected_repository,
|
||||
working_dir,
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
async def update_agent_server_conversation_title(
|
||||
@@ -1223,8 +1124,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||
) -> AppConversation | None:
|
||||
"""Update an app conversation and return it. Return None if the conversation
|
||||
did not exist.
|
||||
"""
|
||||
did not exist."""
|
||||
info = await self.app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
@@ -541,8 +541,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
|
||||
def _fix_timezone(self, value: datetime) -> datetime:
|
||||
"""Sqlite does not stpre timezones - and since we can't update the existing models
|
||||
we assume UTC if the timezone is missing.
|
||||
"""
|
||||
we assume UTC if the timezone is missing."""
|
||||
if not value.tzinfo:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
return value
|
||||
|
||||
@@ -68,8 +68,7 @@ class StoredAppConversationStartTask(Base): # type: ignore
|
||||
class SQLAppConversationStartTaskService(AppConversationStartTaskService):
|
||||
"""SQL implementation of AppConversationStartTaskService focused on db operations.
|
||||
|
||||
This allows storing and retrieving conversation start tasks from the database.
|
||||
"""
|
||||
This allows storing and retrieving conversation start tasks from the database."""
|
||||
|
||||
session: AsyncSession
|
||||
user_id: str | None = None
|
||||
|
||||
@@ -280,6 +280,9 @@ e2b-code-interpreter = { version = "^2.0.0", optional = true }
|
||||
pybase62 = "^1.0.0"
|
||||
|
||||
# V1 dependencies
|
||||
#openhands-agent-server = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-agent-server", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
#openhands-sdk = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-sdk", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
#openhands-tools = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-tools", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
openhands-sdk = "1.10"
|
||||
openhands-agent-server = "1.10"
|
||||
openhands-tools = "1.10"
|
||||
|
||||
@@ -1786,855 +1786,3 @@ class TestLiveStatusAppConversationService:
|
||||
stdio_server = mcp_servers['stdio-server']
|
||||
assert stdio_server['command'] == 'npx'
|
||||
assert stdio_server['env'] == {'TOKEN': 'value'}
|
||||
|
||||
|
||||
class TestPluginHandling:
|
||||
"""Test cases for plugin-related functionality in LiveStatusAppConversationService."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
# Create mock dependencies
|
||||
self.mock_user_context = Mock(spec=UserContext)
|
||||
self.mock_user_auth = Mock()
|
||||
self.mock_user_context.user_auth = self.mock_user_auth
|
||||
self.mock_jwt_service = Mock()
|
||||
self.mock_sandbox_service = Mock()
|
||||
self.mock_sandbox_spec_service = Mock()
|
||||
self.mock_app_conversation_info_service = Mock()
|
||||
self.mock_app_conversation_start_task_service = Mock()
|
||||
self.mock_event_callback_service = Mock()
|
||||
self.mock_event_service = Mock()
|
||||
self.mock_httpx_client = Mock()
|
||||
|
||||
# Create service instance
|
||||
self.service = LiveStatusAppConversationService(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=self.mock_user_context,
|
||||
app_conversation_info_service=self.mock_app_conversation_info_service,
|
||||
app_conversation_start_task_service=self.mock_app_conversation_start_task_service,
|
||||
event_callback_service=self.mock_event_callback_service,
|
||||
event_service=self.mock_event_service,
|
||||
sandbox_service=self.mock_sandbox_service,
|
||||
sandbox_spec_service=self.mock_sandbox_spec_service,
|
||||
jwt_service=self.mock_jwt_service,
|
||||
sandbox_startup_timeout=30,
|
||||
sandbox_startup_poll_frequency=1,
|
||||
httpx_client=self.mock_httpx_client,
|
||||
web_url='https://test.example.com',
|
||||
openhands_provider_base_url='https://provider.example.com',
|
||||
access_token_hard_timeout=None,
|
||||
app_mode='test',
|
||||
)
|
||||
|
||||
# Mock user info
|
||||
self.mock_user = Mock()
|
||||
self.mock_user.id = 'test_user_123'
|
||||
self.mock_user.llm_model = 'gpt-4'
|
||||
self.mock_user.llm_base_url = 'https://api.openai.com/v1'
|
||||
self.mock_user.llm_api_key = 'test_api_key'
|
||||
self.mock_user.confirmation_mode = False
|
||||
self.mock_user.search_api_key = None
|
||||
self.mock_user.condenser_max_size = None
|
||||
self.mock_user.mcp_config = None
|
||||
self.mock_user.security_analyzer = None
|
||||
|
||||
# Mock sandbox
|
||||
self.mock_sandbox = Mock(spec=SandboxInfo)
|
||||
self.mock_sandbox.id = uuid4()
|
||||
self.mock_sandbox.status = SandboxStatus.RUNNING
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_no_plugins(self):
|
||||
"""Test _construct_initial_message_with_plugin_params with no plugins returns original message."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
|
||||
# Test with None initial message and None plugins
|
||||
result = self.service._construct_initial_message_with_plugin_params(None, None)
|
||||
assert result is None
|
||||
|
||||
# Test with None initial message and empty plugins list
|
||||
result = self.service._construct_initial_message_with_plugin_params(None, [])
|
||||
assert result is None
|
||||
|
||||
# Test with initial message but None plugins
|
||||
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, None
|
||||
)
|
||||
assert result is initial_msg
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_no_params(self):
|
||||
"""Test _construct_initial_message_with_plugin_params with plugins but no parameters."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Plugin with no parameters
|
||||
plugins = [PluginSpec(source='github:owner/repo')]
|
||||
|
||||
# Test with None initial message
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
None, plugins
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Test with initial message
|
||||
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
assert result is initial_msg
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_creates_new_message(self):
|
||||
"""Test _construct_initial_message_with_plugin_params creates message when no initial message."""
|
||||
from openhands.agent_server.models import TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/repo',
|
||||
parameters={'api_key': 'test123', 'debug': True},
|
||||
)
|
||||
]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
None, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Plugin Configuration Parameters:' in result.content[0].text
|
||||
assert '- api_key: test123' in result.content[0].text
|
||||
assert '- debug: True' in result.content[0].text
|
||||
assert result.run is True
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_appends_to_message(self):
|
||||
"""Test _construct_initial_message_with_plugin_params appends to existing message."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
initial_msg = SendMessageRequest(
|
||||
content=[TextContent(text='Please analyze this codebase')],
|
||||
run=False,
|
||||
)
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/repo',
|
||||
ref='v1.0.0',
|
||||
parameters={'target_dir': '/src', 'verbose': True},
|
||||
)
|
||||
]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
text = result.content[0].text
|
||||
assert text.startswith('Please analyze this codebase')
|
||||
assert 'Plugin Configuration Parameters:' in text
|
||||
assert '- target_dir: /src' in text
|
||||
assert '- verbose: True' in text
|
||||
assert result.run is False
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_preserves_role(self):
|
||||
"""Test _construct_initial_message_with_plugin_params preserves message role."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
initial_msg = SendMessageRequest(
|
||||
role='system',
|
||||
content=[TextContent(text='System message')],
|
||||
)
|
||||
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.role == 'system'
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_empty_content(self):
|
||||
"""Test _construct_initial_message_with_plugin_params handles empty content list."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
initial_msg = SendMessageRequest(content=[])
|
||||
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Plugin Configuration Parameters:' in result.content[0].text
|
||||
|
||||
def test_construct_initial_message_with_multiple_plugins(self):
|
||||
"""Test _construct_initial_message_with_plugin_params handles multiple plugins."""
|
||||
from openhands.agent_server.models import TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/plugin1',
|
||||
parameters={'key1': 'value1'},
|
||||
),
|
||||
PluginSpec(
|
||||
source='github:owner/plugin2',
|
||||
parameters={'key2': 'value2'},
|
||||
),
|
||||
]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
None, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
text = result.content[0].text
|
||||
assert 'Plugin Configuration Parameters:' in text
|
||||
# Multiple plugins should show grouped by plugin name
|
||||
assert 'plugin1' in text
|
||||
assert 'plugin2' in text
|
||||
assert 'key1: value1' in text
|
||||
assert 'key2: value2' in text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_with_plugins(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request passes plugins list to StartConversationRequest."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {'test': StaticSecret(value='secret')}
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/my-plugin',
|
||||
ref='v1.0.0',
|
||||
parameters={'api_key': 'test123'},
|
||||
)
|
||||
]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 1
|
||||
assert result.plugins[0].source == 'github:owner/my-plugin'
|
||||
assert result.plugins[0].ref == 'v1.0.0'
|
||||
# Also verify initial message contains plugin params
|
||||
assert result.initial_message is not None
|
||||
assert (
|
||||
'Plugin Configuration Parameters:' in result.initial_message.content[0].text
|
||||
)
|
||||
assert '- api_key: test123' in result.initial_message.content[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_without_plugins(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request without plugins sets plugins to None."""
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_plugin_without_ref(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request with plugin that has no ref."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Plugin without ref or parameters
|
||||
plugins = [PluginSpec(source='github:owner/my-plugin')]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 1
|
||||
assert result.plugins[0].source == 'github:owner/my-plugin'
|
||||
assert result.plugins[0].ref is None
|
||||
# No parameters, so initial message should be None
|
||||
assert result.initial_message is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_plugin_with_repo_path(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request passes repo_path to PluginSource."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Plugin with repo_path (for marketplace repos containing multiple plugins)
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/marketplace-repo',
|
||||
ref='main',
|
||||
repo_path='plugins/city-weather',
|
||||
)
|
||||
]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 1
|
||||
assert result.plugins[0].source == 'github:owner/marketplace-repo'
|
||||
assert result.plugins[0].ref == 'main'
|
||||
assert result.plugins[0].repo_path == 'plugins/city-weather'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_multiple_plugins(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request with multiple plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Multiple plugins
|
||||
plugins = [
|
||||
PluginSpec(source='github:owner/security-plugin', ref='v2.0.0'),
|
||||
PluginSpec(
|
||||
source='github:owner/monorepo',
|
||||
repo_path='plugins/logging',
|
||||
),
|
||||
PluginSpec(source='/local/path/to/plugin'),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 3
|
||||
assert result.plugins[0].source == 'github:owner/security-plugin'
|
||||
assert result.plugins[0].ref == 'v2.0.0'
|
||||
assert result.plugins[1].source == 'github:owner/monorepo'
|
||||
assert result.plugins[1].repo_path == 'plugins/logging'
|
||||
assert result.plugins[2].source == '/local/path/to/plugin'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_start_conversation_request_for_user_with_plugins(self):
|
||||
"""Test _build_start_conversation_request_for_user passes plugins to finalize method."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
self.mock_user_context.get_user_info.return_value = self.mock_user
|
||||
self.mock_user_context.get_secrets.return_value = {}
|
||||
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
|
||||
self.mock_user_context.get_mcp_api_key.return_value = None
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='https://github.com/org/plugin.git',
|
||||
ref='main',
|
||||
parameters={'config_file': 'custom.yaml'},
|
||||
)
|
||||
]
|
||||
|
||||
# Mock _finalize_conversation_request to capture the call
|
||||
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
|
||||
self.service._finalize_conversation_request = mock_finalize
|
||||
|
||||
# Act
|
||||
await self.service._build_start_conversation_request_for_user(
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'/workspace',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_finalize.assert_called_once()
|
||||
call_kwargs = mock_finalize.call_args.kwargs
|
||||
assert call_kwargs['plugins'] == plugins
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_start_conversation_request_for_user_without_plugins(self):
|
||||
"""Test _build_start_conversation_request_for_user works without plugins."""
|
||||
# Arrange
|
||||
self.mock_user_context.get_user_info.return_value = self.mock_user
|
||||
self.mock_user_context.get_secrets.return_value = {}
|
||||
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
|
||||
self.mock_user_context.get_mcp_api_key.return_value = None
|
||||
|
||||
# Mock _finalize_conversation_request
|
||||
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
|
||||
self.service._finalize_conversation_request = mock_finalize
|
||||
|
||||
# Act
|
||||
await self.service._build_start_conversation_request_for_user(
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'/workspace',
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_finalize.assert_called_once()
|
||||
call_kwargs = mock_finalize.call_args.kwargs
|
||||
assert call_kwargs.get('plugins') is None
|
||||
|
||||
|
||||
class TestPluginSpecModel:
|
||||
"""Test cases for the PluginSpec model."""
|
||||
|
||||
def test_plugin_spec_with_all_fields(self):
|
||||
"""Test PluginSpec with all fields provided."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
ref='v1.0.0',
|
||||
repo_path='plugins/my-plugin',
|
||||
parameters={'key1': 'value1', 'key2': 123, 'key3': True},
|
||||
)
|
||||
|
||||
assert plugin.source == 'github:owner/repo'
|
||||
assert plugin.ref == 'v1.0.0'
|
||||
assert plugin.repo_path == 'plugins/my-plugin'
|
||||
assert plugin.parameters == {'key1': 'value1', 'key2': 123, 'key3': True}
|
||||
|
||||
def test_plugin_spec_with_only_source(self):
|
||||
"""Test PluginSpec with only source provided."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='https://github.com/owner/repo.git')
|
||||
|
||||
assert plugin.source == 'https://github.com/owner/repo.git'
|
||||
assert plugin.ref is None
|
||||
assert plugin.repo_path is None
|
||||
assert plugin.parameters is None
|
||||
|
||||
def test_plugin_spec_serialization(self):
|
||||
"""Test PluginSpec serialization to JSON."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
ref='main',
|
||||
repo_path='plugins/my-plugin',
|
||||
parameters={'debug': True},
|
||||
)
|
||||
|
||||
json_data = plugin.model_dump()
|
||||
assert json_data == {
|
||||
'source': 'github:owner/repo',
|
||||
'ref': 'main',
|
||||
'repo_path': 'plugins/my-plugin',
|
||||
'parameters': {'debug': True},
|
||||
}
|
||||
|
||||
def test_plugin_spec_deserialization(self):
|
||||
"""Test PluginSpec deserialization from dict."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
data = {
|
||||
'source': 'github:owner/repo',
|
||||
'ref': 'v2.0.0',
|
||||
'repo_path': 'plugins/weather',
|
||||
'parameters': {'timeout': 30},
|
||||
}
|
||||
|
||||
plugin = PluginSpec.model_validate(data)
|
||||
|
||||
assert plugin.source == 'github:owner/repo'
|
||||
assert plugin.ref == 'v2.0.0'
|
||||
assert plugin.repo_path == 'plugins/weather'
|
||||
assert plugin.parameters == {'timeout': 30}
|
||||
|
||||
def test_plugin_spec_display_name_github_format(self):
|
||||
"""Test display_name extracts repo name from github:owner/repo format."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='github:owner/my-plugin')
|
||||
assert plugin.display_name == 'my-plugin'
|
||||
|
||||
def test_plugin_spec_display_name_git_url(self):
|
||||
"""Test display_name extracts repo name from git URL."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='https://github.com/owner/repo.git')
|
||||
assert plugin.display_name == 'repo.git'
|
||||
|
||||
def test_plugin_spec_display_name_local_path(self):
|
||||
"""Test display_name extracts directory name from local path."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='/local/path/to/plugin')
|
||||
assert plugin.display_name == 'plugin'
|
||||
|
||||
def test_plugin_spec_display_name_no_slash(self):
|
||||
"""Test display_name returns source as-is when no slash present."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='local-plugin')
|
||||
assert plugin.display_name == 'local-plugin'
|
||||
|
||||
def test_plugin_spec_format_params_as_text(self):
|
||||
"""Test format_params_as_text formats parameters as text."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
parameters={'key1': 'value1', 'key2': 123},
|
||||
)
|
||||
|
||||
result = plugin.format_params_as_text()
|
||||
assert result == '- key1: value1\n- key2: 123'
|
||||
|
||||
def test_plugin_spec_format_params_as_text_with_indent(self):
|
||||
"""Test format_params_as_text with custom indent."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
parameters={'debug': True},
|
||||
)
|
||||
|
||||
result = plugin.format_params_as_text(indent=' ')
|
||||
assert result == ' - debug: True'
|
||||
|
||||
def test_plugin_spec_format_params_as_text_no_params(self):
|
||||
"""Test format_params_as_text returns None when no parameters."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='github:owner/repo')
|
||||
assert plugin.format_params_as_text() is None
|
||||
|
||||
def test_plugin_spec_inherits_repo_path_validation(self):
|
||||
"""Test PluginSpec inherits validation from SDK's PluginSource."""
|
||||
import pytest
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Should reject absolute paths
|
||||
with pytest.raises(ValueError, match='must be relative'):
|
||||
PluginSpec(source='github:owner/repo', repo_path='/absolute/path')
|
||||
|
||||
# Should reject parent traversal
|
||||
with pytest.raises(ValueError, match="cannot contain '..'"):
|
||||
PluginSpec(source='github:owner/repo', repo_path='../parent/path')
|
||||
|
||||
|
||||
class TestAppConversationStartRequestWithPlugins:
|
||||
"""Test cases for AppConversationStartRequest with plugins field."""
|
||||
|
||||
def test_start_request_with_plugins(self):
|
||||
"""Test AppConversationStartRequest with plugins field."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/my-plugin',
|
||||
ref='v1.0.0',
|
||||
parameters={'api_key': 'test'},
|
||||
)
|
||||
]
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
title='Test conversation',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
assert request.plugins is not None
|
||||
assert len(request.plugins) == 1
|
||||
assert request.plugins[0].source == 'github:owner/my-plugin'
|
||||
assert request.plugins[0].ref == 'v1.0.0'
|
||||
assert request.plugins[0].parameters == {'api_key': 'test'}
|
||||
|
||||
def test_start_request_without_plugins(self):
|
||||
"""Test AppConversationStartRequest without plugins field (backwards compatible)."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
)
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
title='Test conversation',
|
||||
)
|
||||
|
||||
assert request.plugins is None
|
||||
|
||||
def test_start_request_serialization_with_plugins(self):
|
||||
"""Test AppConversationStartRequest serialization includes plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [PluginSpec(source='github:owner/repo')]
|
||||
request = AppConversationStartRequest(plugins=plugins)
|
||||
|
||||
json_data = request.model_dump()
|
||||
|
||||
assert 'plugins' in json_data
|
||||
assert len(json_data['plugins']) == 1
|
||||
assert json_data['plugins'][0]['source'] == 'github:owner/repo'
|
||||
|
||||
def test_start_request_deserialization_with_plugins(self):
|
||||
"""Test AppConversationStartRequest deserialization from JSON with plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
)
|
||||
|
||||
data = {
|
||||
'title': 'Test',
|
||||
'plugins': [
|
||||
{
|
||||
'source': 'github:owner/plugin',
|
||||
'ref': 'main',
|
||||
'parameters': {'key': 'value'},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
request = AppConversationStartRequest.model_validate(data)
|
||||
|
||||
assert request.plugins is not None
|
||||
assert len(request.plugins) == 1
|
||||
assert request.plugins[0].source == 'github:owner/plugin'
|
||||
assert request.plugins[0].ref == 'main'
|
||||
assert request.plugins[0].parameters == {'key': 'value'}
|
||||
|
||||
def test_start_request_with_multiple_plugins(self):
|
||||
"""Test AppConversationStartRequest with multiple plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(source='github:owner/plugin1', ref='v1.0.0'),
|
||||
PluginSpec(source='github:owner/plugin2', repo_path='plugins/sub'),
|
||||
PluginSpec(source='/local/path'),
|
||||
]
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
title='Test conversation',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
assert request.plugins is not None
|
||||
assert len(request.plugins) == 3
|
||||
assert request.plugins[0].source == 'github:owner/plugin1'
|
||||
assert request.plugins[1].repo_path == 'plugins/sub'
|
||||
assert request.plugins[2].source == '/local/path'
|
||||
|
||||
Reference in New Issue
Block a user