Compare commits

..

1 Commits

Author SHA1 Message Date
openhands 48efca5f34 Add PRD for Organization Code Review Bot
Defines the problem statement, goals, user personas, and key use cases for a
first-class org-specific code review bot with telemetry and one-click remediation.

Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-21 18:35:01 +00:00
87 changed files with 4098 additions and 10431 deletions
+2 -4
View File
@@ -39,7 +39,8 @@ jobs:
run: |
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
json=$(jq -n -c '[
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" }
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" },
{ image: "ubuntu:24.04", tag: "ubuntu" }
]')
else
json=$(jq -n -c '[
@@ -148,9 +149,6 @@ jobs:
push: true
tags: ${{ env.DOCKER_TAGS }}
platforms: ${{ env.DOCKER_PLATFORM }}
# Caching directives to boost performance
cache-from: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}
cache-to: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }},mode=max
build-args: ${{ env.DOCKER_BUILD_ARGS }}
context: containers/runtime
provenance: false
+2 -2
View File
@@ -7,8 +7,8 @@ services:
image: openhands:latest
container_name: openhands-app-${DATE:-}
environment:
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-31536c8-python}
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-docker.openhands.dev/openhands/runtime}
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.2-nikolaik}
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
ports:
@@ -0,0 +1,62 @@
# Organization Code Review Bot - PRD
## Problem Statement
### Engineering Leader Perspective
AI coding agents have made code generation cheap, but verification remains the bottleneck. Engineering leaders lack visibility into:
- **Review quality**: How effective are AI-generated code reviews vs. human reviews?
- **Developer engagement**: Are developers acting on AI review feedback?
- **Organizational patterns**: What recurring feedback themes exist across the organization that could be codified?
Without telemetry, orgs cannot measure ROI of AI-assisted code review or systematically improve their review standards.
### Developer Perspective
Current AI code reviews produce itemized feedback, but acting on that feedback is manual and disconnected:
- Developers must context-switch to address each review item separately
- No one-click path from "review comment" → "fix implementation"
- No feedback loop to help the AI learn which suggestions are valuable vs. noise
## Goals
1. **One-click remediation**: Enable developers to launch an OpenHands conversation directly from a review item to address it
2. **Org-wide telemetry**: Track accept/dismiss rates on review items to measure review quality and surface patterns
3. **Learned review standards**: Distill recurring org-specific feedback into a lightweight review standard the bot applies automatically
4. **Verification signals**: Integrate code survival metrics (what % of AI-written code survives to merge) to predict review quality
## User Personas
| Persona | Description |
|---------|-------------|
| **Developer** | Uses the bot to get PR reviews and quickly address feedback items via one-click OpenHands sessions |
| **Tech Lead** | Reviews org-wide feedback patterns to identify common issues and improve team coding standards |
| **Engineering Manager** | Monitors accept/dismiss telemetry to assess AI review effectiveness and developer adoption |
| **Platform Engineer** | Configures org-specific review rules and integrates the bot with existing CI/CD workflows |
## Key Use Cases
### 1. One-Click Review Item Remediation
- Developer receives AI code review with itemized feedback
- Each feedback item has a "Fix with OpenHands" button that launches a scoped conversation to address that specific issue
- Context (diff, review comment, file) is automatically passed to the agent
### 2. Accept/Dismiss Feedback Telemetry
- Developers can mark review items as "Agree & Fix" or "Dismiss"
- Org-wide dashboard shows aggregate accept/dismiss rates per review category
- Identifies high-value feedback patterns vs. low-signal noise
### 3. Org-Specific Review Standards
- Platform engineer configures org-specific review rules (e.g., "always check for error handling in API routes")
- Bot learns from historical code reviews to surface org-specific patterns
- Review standards are versioned and auditable
### 4. Code Survival Metrics
- Track what fraction of AI-suggested changes make it into the merged PR
- Surface low-survival patterns to improve review prompt quality
- Use survival signals to predict whether a review item is likely to be addressed
### 5. Review Quality Dashboard
- Engineering managers see per-team and per-repo review effectiveness metrics
- Trending view of common feedback categories over time
- Alerts when review quality drops or dismiss rates spike
-205
View File
@@ -1,205 +0,0 @@
#!/usr/bin/env python
"""
Downgrade script for migrated users.
This script identifies users who have been migrated (already_migrated=True)
and reverts them back to the pre-migration state.
Usage:
# Dry run - just list the users that would be downgraded
python downgrade_migrated_users.py --dry-run
# Downgrade a specific user by their keycloak_user_id
python downgrade_migrated_users.py --user-id <user_id>
# Downgrade all migrated users (with confirmation)
python downgrade_migrated_users.py --all
# Downgrade all migrated users without confirmation (dangerous!)
python downgrade_migrated_users.py --all --no-confirm
"""
import argparse
import asyncio
import sys
# Add the enterprise directory to the path
sys.path.insert(0, '/workspace/project/OpenHands/enterprise')
from server.logger import logger
from sqlalchemy import select, text
from storage.database import session_maker
from storage.user_settings import UserSettings
from storage.user_store import UserStore
def get_migrated_users() -> list[str]:
"""Get list of keycloak_user_ids for users who have been migrated.
This includes:
1. Users with already_migrated=True in user_settings (migrated users)
2. Users in the 'user' table who don't have a user_settings entry (new sign-ups)
"""
with session_maker() as session:
# Get users from user_settings with already_migrated=True
migrated_result = session.execute(
select(UserSettings.keycloak_user_id).where(
UserSettings.already_migrated.is_(True)
)
)
migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]}
# Get users from the 'user' table (new sign-ups won't have user_settings)
# These are users who signed up after the migration was deployed
new_signup_result = session.execute(
text("""
SELECT CAST(u.id AS VARCHAR)
FROM "user" u
WHERE NOT EXISTS (
SELECT 1 FROM user_settings us
WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR)
)
""")
)
new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]}
# Combine both sets
all_users = migrated_users | new_signups
return list(all_users)
async def downgrade_user(user_id: str) -> bool:
"""Downgrade a single user.
Args:
user_id: The keycloak_user_id to downgrade
Returns:
True if successful, False otherwise
"""
try:
result = await UserStore.downgrade_user(user_id)
if result:
print(f'✓ Successfully downgraded user: {user_id}')
return True
else:
print(f'✗ Failed to downgrade user: {user_id}')
return False
except Exception as e:
print(f'✗ Error downgrading user {user_id}: {e}')
logger.exception(
'downgrade_script:error',
extra={'user_id': user_id, 'error': str(e)},
)
return False
async def main():
parser = argparse.ArgumentParser(
description='Downgrade migrated users back to pre-migration state'
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Just list users that would be downgraded, without making changes',
)
parser.add_argument(
'--user-id',
type=str,
help='Downgrade a specific user by keycloak_user_id',
)
parser.add_argument(
'--all',
action='store_true',
help='Downgrade all migrated users',
)
parser.add_argument(
'--no-confirm',
action='store_true',
help='Skip confirmation prompt (use with caution!)',
)
args = parser.parse_args()
# Get list of migrated users
migrated_users = get_migrated_users()
print(f'\nFound {len(migrated_users)} migrated user(s).')
if args.dry_run:
print('\n--- DRY RUN MODE ---')
print('The following users would be downgraded:')
for user_id in migrated_users:
print(f' - {user_id}')
print('\nNo changes were made.')
return
if args.user_id:
# Downgrade a specific user
if args.user_id not in migrated_users:
print(f'\nUser {args.user_id} is not in the migrated users list.')
print('Either the user was not migrated, or the user_id is incorrect.')
return
print(f'\nDowngrading user: {args.user_id}')
if not args.no_confirm:
confirm = input('Are you sure? (yes/no): ')
if confirm.lower() != 'yes':
print('Cancelled.')
return
success = await downgrade_user(args.user_id)
if success:
print('\nDowngrade completed successfully.')
else:
print('\nDowngrade failed. Check logs for details.')
sys.exit(1)
elif args.all:
# Downgrade all migrated users
if not migrated_users:
print('\nNo migrated users to downgrade.')
return
print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).')
if not args.no_confirm:
print('\nThis will:')
print(' - Revert LiteLLM team/user budget settings')
print(' - Delete organization entries')
print(' - Delete user entries in the new schema')
print(' - Reset the already_migrated flag')
print('\nUsers to downgrade:')
for user_id in migrated_users[:10]: # Show first 10
print(f' - {user_id}')
if len(migrated_users) > 10:
print(f' ... and {len(migrated_users) - 10} more')
confirm = input('\nType "yes" to proceed: ')
if confirm.lower() != 'yes':
print('Cancelled.')
return
print('\nStarting downgrade...\n')
success_count = 0
fail_count = 0
for user_id in migrated_users:
success = await downgrade_user(user_id)
if success:
success_count += 1
else:
fail_count += 1
print('\n--- Summary ---')
print(f'Successful: {success_count}')
print(f'Failed: {fail_count}')
if fail_count > 0:
sys.exit(1)
else:
parser.print_help()
print('\nPlease specify --dry-run, --user-id, or --all')
if __name__ == '__main__':
asyncio.run(main())
@@ -26,14 +26,12 @@ from integrations.utils import (
from integrations.v1_utils import get_saas_user_auth
from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
from server.auth.auth_error import ExpiredError
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.integrations.service_types import AuthenticationError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
@@ -349,7 +347,7 @@ class GithubManager(Manager):
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except (AuthenticationError, ExpiredError, SessionExpiredError) as e:
except SessionExpiredError as e:
logger.warning(
f'[GitHub] Session expired for user {user_info.username}: {str(e)}'
)
+352 -290
View File
@@ -1,37 +1,18 @@
"""Jira integration manager.
This module orchestrates the processing of Jira webhook events:
1. Parse webhook payload (via JiraPayloadParser)
2. Validate workspace
3. Authenticate user
4. Create view with repository selection (via JiraFactory)
5. Start conversation job
The manager delegates payload parsing to JiraPayloadParser and view creation
to JiraFactory, keeping the orchestration logic clean and traceable.
"""
from typing import Tuple
from urllib.parse import urlparse
import httpx
from integrations.jira.jira_payload import (
JiraPayloadError,
JiraPayloadParser,
JiraPayloadSkipped,
JiraPayloadSuccess,
JiraWebhookPayload,
from integrations.jira.jira_types import JiraViewInterface
from integrations.jira.jira_view import (
JiraFactory,
JiraNewConversationView,
)
from integrations.jira.jira_types import (
JiraViewInterface,
RepositoryNotFoundError,
StartingConvoException,
)
from integrations.jira.jira_view import JiraFactory, JiraNewConversationView
from integrations.manager import Manager
from integrations.models import Message
from integrations.models import JobContext, Message
from integrations.utils import (
HOST,
HOST_URL,
OPENHANDS_RESOLVER_TEMPLATES_DIR,
get_oh_labels,
filter_potential_repos_by_user_msg,
get_session_expired_message,
)
from jinja2 import Environment, FileSystemLoader
@@ -43,6 +24,9 @@ from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
@@ -53,211 +37,267 @@ from openhands.utils.http_session import httpx_verify_option
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
# Get OH labels for this environment
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
class JiraManager(Manager):
"""Manager for processing Jira webhook events.
This class orchestrates the flow from webhook receipt to conversation creation,
delegating parsing to JiraPayloadParser and view creation to JiraFactory.
"""
def __init__(self, token_manager: TokenManager):
self.token_manager = token_manager
self.integration_store = JiraIntegrationStore.get_instance()
self.jinja_env = Environment(
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira')
)
self.payload_parser = JiraPayloadParser(
oh_label=OH_LABEL,
inline_oh_label=INLINE_OH_LABEL,
)
async def receive_message(self, message: Message):
"""Process incoming Jira webhook message.
Flow:
1. Parse webhook payload
2. Validate workspace exists and is active
3. Authenticate user
4. Create view (includes fetching issue details and selecting repository)
5. Start job
Each step has clear logging for traceability.
"""
raw_payload = message.message.get('payload', {})
# Step 1: Parse webhook payload
logger.info(
'[Jira] Received webhook',
extra={'raw_payload': raw_payload},
)
parse_result = self.payload_parser.parse(raw_payload)
if isinstance(parse_result, JiraPayloadSkipped):
logger.info(
'[Jira] Webhook skipped', extra={'reason': parse_result.skip_reason}
)
return
if isinstance(parse_result, JiraPayloadError):
logger.warning(
'[Jira] Webhook parse failed', extra={'error': parse_result.error}
)
return
payload = parse_result.payload
logger.info(
'[Jira] Processing webhook',
extra={
'event_type': payload.event_type.value,
'issue_key': payload.issue_key,
'user_email': payload.user_email,
},
)
# Step 2: Validate workspace
workspace = await self._get_active_workspace(payload)
if not workspace:
return
# Step 3: Authenticate user
jira_user, saas_user_auth = await self._authenticate_user(payload, workspace)
if not jira_user or not saas_user_auth:
return
# Step 4: Create view (includes issue details fetch and repo selection)
decrypted_api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
try:
view = await JiraFactory.create_view(
payload=payload,
workspace=workspace,
user=jira_user,
user_auth=saas_user_auth,
decrypted_api_key=decrypted_api_key,
)
except RepositoryNotFoundError as e:
logger.warning(
'[Jira] Repository not found',
extra={'issue_key': payload.issue_key, 'error': str(e)},
)
await self._send_error_from_payload(payload, workspace, str(e))
return
except StartingConvoException as e:
logger.warning(
'[Jira] View creation failed',
extra={'issue_key': payload.issue_key, 'error': str(e)},
)
await self._send_error_from_payload(payload, workspace, str(e))
return
except Exception as e:
logger.error(
'[Jira] Unexpected error creating view',
extra={'issue_key': payload.issue_key, 'error': str(e)},
exc_info=True,
)
await self._send_error_from_payload(
payload,
workspace,
'Failed to initialize conversation. Please try again.',
)
return
# Step 5: Start job
await self.start_job(view)
async def _get_active_workspace(
self, payload: JiraWebhookPayload
) -> JiraWorkspace | None:
"""Validate and return the workspace for the webhook.
Returns None if:
- Workspace not found
- Workspace is inactive
- Request is from service account (to prevent recursion)
"""
workspace = await self.integration_store.get_workspace_by_name(
payload.workspace_name
)
if not workspace:
logger.warning(
'[Jira] Workspace not found',
extra={'workspace_name': payload.workspace_name},
)
# Can't send error without workspace credentials
return None
# Prevent recursive triggers from service account
if payload.user_email == workspace.svc_acc_email:
logger.debug(
'[Jira] Ignoring service account trigger',
extra={'workspace_name': payload.workspace_name},
)
return None
if workspace.status != 'active':
logger.warning(
'[Jira] Workspace inactive',
extra={'workspace_id': workspace.id, 'status': workspace.status},
)
await self._send_error_from_payload(
payload, workspace, 'Jira integration is not active for your workspace.'
)
return None
return workspace
async def _authenticate_user(
self, payload: JiraWebhookPayload, workspace: JiraWorkspace
async def authenticate_user(
self, jira_user_id: str, workspace_id: int
) -> tuple[JiraUser | None, UserAuth | None]:
"""Authenticate the Jira user and get OpenHands auth."""
"""Authenticate Jira user and get their OpenHands user auth."""
# Find active Jira user by Keycloak user ID and workspace ID
jira_user = await self.integration_store.get_active_user(
payload.account_id, workspace.id
jira_user_id, workspace_id
)
if not jira_user:
logger.warning(
'[Jira] User not found or inactive',
extra={
'account_id': payload.account_id,
'user_email': payload.user_email,
'workspace_id': workspace.id,
},
)
await self._send_error_from_payload(
payload,
workspace,
f'User {payload.user_email} is not authenticated or active in the Jira integration.',
f'[Jira] No active Jira user found for {jira_user_id} in workspace {workspace_id}'
)
return None, None
saas_user_auth = await get_user_auth_from_keycloak_id(
jira_user.keycloak_user_id
)
if not saas_user_auth:
logger.warning(
'[Jira] Failed to get OpenHands auth',
extra={
'keycloak_user_id': jira_user.keycloak_user_id,
'user_email': payload.user_email,
},
)
await self._send_error_from_payload(
payload,
workspace,
f'User {payload.user_email} is not authenticated with OpenHands.',
)
return None, None
return jira_user, saas_user_auth
async def start_job(self, view: JiraViewInterface):
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
"""Get repositories that the user has access to."""
provider_tokens = await user_auth.get_provider_tokens()
if provider_tokens is None:
return []
access_token = await user_auth.get_access_token()
user_id = await user_auth.get_user_id()
client = ProviderHandler(
provider_tokens=provider_tokens,
external_auth_token=access_token,
external_auth_id=user_id,
)
repos: list[Repository] = await client.get_repositories(
'pushed', server_config.app_mode, None, None, None, None
)
return repos
def get_workspace_name_from_payload(self, payload: dict) -> str | None:
"""Extract workspace name from Jira webhook payload."""
if payload.get('webhookEvent') == 'comment_created':
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
elif payload.get('webhookEvent') == 'jira:issue_updated':
selfUrl = payload.get('user', {}).get('self')
else:
return None
if not selfUrl:
return None
parsedUrl = urlparse(selfUrl)
return parsedUrl.hostname or None
def parse_webhook(self, message: Message) -> JobContext | None:
payload = message.message.get('payload', {})
issue_data = payload.get('issue', {})
issue_id = issue_data.get('id')
issue_key = issue_data.get('key')
self_url = issue_data.get('self', '')
if not self_url:
logger.warning('[Jira] Missing self URL in issue data')
base_api_url = ''
elif '/rest/' in self_url:
base_api_url = self_url.split('/rest/')[0]
else:
# Fallback: extract base URL using urlparse
parsed = urlparse(self_url)
base_api_url = f'{parsed.scheme}://{parsed.netloc}'
comment = ''
if JiraFactory.is_ticket_comment(message):
comment_data = payload.get('comment', {})
comment = comment_data.get('body', '')
user_data: dict = comment_data.get('author', {})
elif JiraFactory.is_labeled_ticket(message):
user_data = payload.get('user', {})
else:
raise ValueError('Unrecognized jira event')
user_email = user_data.get('emailAddress')
display_name = user_data.get('displayName')
account_id = user_data.get('accountId')
workspace_name = ''
parsedUrl = urlparse(base_api_url)
if parsedUrl.hostname:
workspace_name = parsedUrl.hostname
if not all(
[
issue_id,
issue_key,
user_email,
display_name,
account_id,
workspace_name,
base_api_url,
]
):
return None
return JobContext(
issue_id=issue_id,
issue_key=issue_key,
user_msg=comment,
user_email=user_email,
display_name=display_name,
platform_user_id=account_id,
workspace_name=workspace_name,
base_api_url=base_api_url,
)
async def is_job_requested(self, message: Message) -> bool:
return JiraFactory.is_labeled_ticket(message) or JiraFactory.is_ticket_comment(
message
)
async def receive_message(self, message: Message):
"""Process incoming Jira webhook message."""
payload = message.message.get('payload', {})
logger.info('[Jira]: received payload', extra={'payload': payload})
is_job_requested = await self.is_job_requested(message)
if not is_job_requested:
return
job_context = self.parse_webhook(message)
if not job_context:
logger.info(
'[Jira] Failed to parse webhook payload - missing required fields or invalid structure',
extra={'event_type': payload.get('webhookEvent')},
)
return
# Get workspace by user email domain
workspace = await self.integration_store.get_workspace_by_name(
job_context.workspace_name
)
if not workspace:
logger.warning(
f'[Jira] No workspace found for email domain: {job_context.user_email}'
)
await self._send_error_comment(
job_context,
'Your workspace is not configured with Jira integration.',
None,
)
return
# Prevent any recursive triggers from the service account
if job_context.user_email == workspace.svc_acc_email:
return
if workspace.status != 'active':
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
await self._send_error_comment(
job_context,
'Jira integration is not active for your workspace.',
workspace,
)
return
# Authenticate user
jira_user, saas_user_auth = await self.authenticate_user(
job_context.platform_user_id, workspace.id
)
if not jira_user or not saas_user_auth:
logger.warning(
f'[Jira] User authentication failed for {job_context.user_email}'
)
await self._send_error_comment(
job_context,
f'User {job_context.user_email} is not authenticated or active in the Jira integration.',
workspace,
)
return
# Get issue details
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
issue_title, issue_description = await self.get_issue_details(
job_context, workspace.jira_cloud_id, workspace.svc_acc_email, api_key
)
job_context.issue_title = issue_title
job_context.issue_description = issue_description
except Exception as e:
logger.error(f'[Jira] Failed to get issue context: {str(e)}')
await self._send_error_comment(
job_context,
'Failed to retrieve issue details. Please check the issue key and try again.',
workspace,
)
return
try:
# Create Jira view
jira_view = await JiraFactory.create_jira_view_from_payload(
job_context,
saas_user_auth,
jira_user,
workspace,
)
except Exception as e:
logger.error(f'[Jira] Failed to create jira view: {str(e)}', exc_info=True)
await self._send_error_comment(
job_context,
'Failed to initialize conversation. Please try again.',
workspace,
)
return
if not await self.is_repository_specified(message, jira_view):
return
await self.start_job(jira_view)
async def is_repository_specified(
self, message: Message, jira_view: JiraViewInterface
) -> bool:
"""
Check if a job is requested and handle repository selection.
"""
try:
# Get user repositories
user_repos: list[Repository] = await self._get_repositories(
jira_view.saas_user_auth
)
target_str = f'{jira_view.job_context.issue_description}\n{jira_view.job_context.user_msg}'
# Try to infer repository from issue description
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
if match:
# Found exact repository match
jira_view.selected_repo = repos[0].full_name
logger.info(f'[Jira] Inferred repository: {repos[0].full_name}')
return True
else:
# No clear match - send repository selection comment
await self._send_repo_selection_comment(jira_view)
return False
except Exception as e:
logger.error(f'[Jira] Error determining repository: {str(e)}')
return False
async def start_job(self, jira_view: JiraViewInterface):
"""Start a Jira job/conversation."""
# Import here to prevent circular import
from server.conversation_callback_processor.jira_callback_processor import (
@@ -265,79 +305,101 @@ class JiraManager(Manager):
)
try:
user_info: JiraUser = jira_view.jira_user
logger.info(
'[Jira] Starting job',
extra={
'issue_key': view.payload.issue_key,
'user_id': view.jira_user.keycloak_user_id,
'selected_repo': view.selected_repo,
},
f'[Jira] Starting job for user {user_info.keycloak_user_id} '
f'issue {jira_view.job_context.issue_key}',
)
# Create conversation
conversation_id = await view.create_or_update_conversation(self.jinja_env)
conversation_id = await jira_view.create_or_update_conversation(
self.jinja_env
)
logger.info(
'[Jira] Conversation created',
extra={
'conversation_id': conversation_id,
'issue_key': view.payload.issue_key,
},
f'[Jira] Created/Updated conversation {conversation_id} for issue {jira_view.job_context.issue_key}'
)
# Register callback processor for updates
if isinstance(view, JiraNewConversationView):
if isinstance(jira_view, JiraNewConversationView):
processor = JiraCallbackProcessor(
issue_key=view.payload.issue_key,
workspace_name=view.jira_workspace.name,
)
register_callback_processor(conversation_id, processor)
logger.info(
'[Jira] Callback processor registered',
extra={'conversation_id': conversation_id},
issue_key=jira_view.job_context.issue_key,
workspace_name=jira_view.jira_workspace.name,
)
# Send success response
msg_info = view.get_response_msg()
# Register the callback processor
register_callback_processor(conversation_id, processor)
logger.info(
f'[Jira] Created callback processor for conversation {conversation_id}'
)
# Send initial response
msg_info = jira_view.get_response_msg()
except MissingSettingsError as e:
logger.warning(
'[Jira] Missing settings error',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
logger.warning(f'[Jira] Missing settings error: {str(e)}')
msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
except LLMAuthenticationError as e:
logger.warning(
'[Jira] LLM authentication error',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
logger.warning(f'[Jira] LLM authentication error: {str(e)}')
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except SessionExpiredError as e:
logger.warning(
'[Jira] Session expired',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
logger.warning(f'[Jira] Session expired: {str(e)}')
msg_info = get_session_expired_message()
except StartingConvoException as e:
logger.warning(
'[Jira] Conversation start failed',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
msg_info = str(e)
except Exception as e:
logger.error(
'[Jira] Unexpected error starting job',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
exc_info=True,
f'[Jira] Unexpected error starting job: {str(e)}', exc_info=True
)
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
# Send response comment
await self._send_comment(view, msg_info)
try:
api_key = self.token_manager.decrypt_text(
jira_view.jira_workspace.svc_acc_api_key
)
await self.send_message(
self.create_outgoing_message(msg=msg_info),
issue_key=jira_view.job_context.issue_key,
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
svc_acc_api_key=api_key,
)
except Exception as e:
logger.error(f'[Jira] Failed to send response message: {str(e)}')
async def get_issue_details(
self,
job_context: JobContext,
jira_cloud_id: str,
svc_acc_email: str,
svc_acc_api_key: str,
) -> Tuple[str, str]:
url = f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{job_context.issue_key}'
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.get(url, auth=(svc_acc_email, svc_acc_api_key))
response.raise_for_status()
issue_payload = response.json()
if not issue_payload:
raise ValueError(f'Issue with key {job_context.issue_key} not found.')
title = issue_payload.get('fields', {}).get('summary', '')
description = issue_payload.get('fields', {}).get('description', '')
if not title:
raise ValueError(
f'Issue with key {job_context.issue_key} does not have a title.'
)
if not description:
raise ValueError(
f'Issue with key {job_context.issue_key} does not have a description.'
)
return title, description
async def send_message(
self,
@@ -347,7 +409,6 @@ class JiraManager(Manager):
svc_acc_email: str,
svc_acc_api_key: str,
):
"""Send a comment to a Jira issue."""
url = (
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
)
@@ -359,53 +420,54 @@ class JiraManager(Manager):
response.raise_for_status()
return response.json()
async def _send_comment(self, view: JiraViewInterface, msg: str):
"""Send a comment using credentials from the view."""
try:
api_key = self.token_manager.decrypt_text(
view.jira_workspace.svc_acc_api_key
)
await self.send_message(
self.create_outgoing_message(msg=msg),
issue_key=view.payload.issue_key,
jira_cloud_id=view.jira_workspace.jira_cloud_id,
svc_acc_email=view.jira_workspace.svc_acc_email,
svc_acc_api_key=api_key,
)
except Exception as e:
logger.error(
'[Jira] Failed to send comment',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
async def _send_error_from_payload(
async def _send_error_comment(
self,
payload: JiraWebhookPayload,
workspace: JiraWorkspace,
job_context: JobContext,
error_msg: str,
workspace: JiraWorkspace | None,
):
"""Send error comment before view is created (using payload directly)."""
"""Send error comment to Jira issue."""
if not workspace:
logger.error('[Jira] Cannot send error comment - no workspace available')
return
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
await self.send_message(
self.create_outgoing_message(msg=error_msg),
issue_key=payload.issue_key,
issue_key=job_context.issue_key,
jira_cloud_id=workspace.jira_cloud_id,
svc_acc_email=workspace.svc_acc_email,
svc_acc_api_key=api_key,
)
except Exception as e:
logger.error(
'[Jira] Failed to send error comment',
extra={'issue_key': payload.issue_key, 'error': str(e)},
logger.error(f'[Jira] Failed to send error comment: {str(e)}')
async def _send_repo_selection_comment(self, jira_view: JiraViewInterface):
"""Send a comment with repository options for the user to choose."""
try:
comment_msg = (
'I need to know which repository to work with. '
'Please add it to your issue description or send a followup comment.'
)
def get_workspace_name_from_payload(self, payload: dict) -> str | None:
"""Extract workspace name from Jira webhook payload.
api_key = self.token_manager.decrypt_text(
jira_view.jira_workspace.svc_acc_api_key
)
This method is used by the route for signature verification.
"""
parse_result = self.payload_parser.parse(payload)
if isinstance(parse_result, JiraPayloadSuccess):
return parse_result.payload.workspace_name
return None
await self.send_message(
self.create_outgoing_message(msg=comment_msg),
issue_key=jira_view.job_context.issue_key,
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
svc_acc_api_key=api_key,
)
logger.info(
f'[Jira] Sent repository selection comment for issue {jira_view.job_context.issue_key}'
)
except Exception as e:
logger.error(
f'[Jira] Failed to send repository selection comment: {str(e)}'
)
@@ -1,267 +0,0 @@
"""Centralized payload parsing for Jira webhooks.
This module provides a single source of truth for parsing and validating
Jira webhook payloads, replacing scattered parsing logic throughout the codebase.
"""
from dataclasses import dataclass
from enum import Enum
from urllib.parse import urlparse
from openhands.core.logger import openhands_logger as logger
class JiraEventType(Enum):
"""Types of Jira events we handle."""
LABELED_TICKET = 'labeled_ticket'
COMMENT_MENTION = 'comment_mention'
@dataclass(frozen=True)
class JiraWebhookPayload:
"""Normalized, validated representation of a Jira webhook payload.
This immutable dataclass replaces JobContext and provides a single
source of truth for all webhook data. All parsing happens in
JiraPayloadParser, ensuring consistent validation.
"""
event_type: JiraEventType
raw_event: str # Original webhookEvent value
# Issue data
issue_id: str
issue_key: str
# User data
user_email: str
display_name: str
account_id: str
# Workspace data (derived from issue self URL)
workspace_name: str
base_api_url: str
# Event-specific data
comment_body: str = '' # For comment events
@property
def user_msg(self) -> str:
"""Alias for comment_body for backward compatibility."""
return self.comment_body
class JiraPayloadParseError(Exception):
"""Raised when payload parsing fails."""
def __init__(self, reason: str, event_type: str | None = None):
self.reason = reason
self.event_type = event_type
super().__init__(reason)
@dataclass(frozen=True)
class JiraPayloadSuccess:
"""Result when parsing succeeds."""
payload: JiraWebhookPayload
@dataclass(frozen=True)
class JiraPayloadSkipped:
"""Result when event is intentionally skipped."""
skip_reason: str
@dataclass(frozen=True)
class JiraPayloadError:
"""Result when parsing fails due to invalid data."""
error: str
JiraPayloadParseResult = JiraPayloadSuccess | JiraPayloadSkipped | JiraPayloadError
class JiraPayloadParser:
"""Centralized parser for Jira webhook payloads.
This class provides a single entry point for parsing webhooks,
determining event types, and extracting all necessary fields.
Replaces scattered parsing in JiraFactory and JiraManager.
"""
def __init__(self, oh_label: str, inline_oh_label: str):
"""Initialize parser with OpenHands label configuration.
Args:
oh_label: Label that triggers OpenHands (e.g., 'openhands')
inline_oh_label: Mention that triggers OpenHands (e.g., '@openhands')
"""
self.oh_label = oh_label
self.inline_oh_label = inline_oh_label
def parse(self, raw_payload: dict) -> JiraPayloadParseResult:
"""Parse a raw webhook payload into a normalized JiraWebhookPayload.
Args:
raw_payload: The raw webhook payload dict from Jira
Returns:
One of:
- JiraPayloadSuccess: Valid, actionable event with payload
- JiraPayloadSkipped: Event we intentionally don't process
- JiraPayloadError: Malformed payload we expected to process
"""
webhook_event = raw_payload.get('webhookEvent', '')
logger.debug(
'[Jira] Parsing webhook payload', extra={'webhook_event': webhook_event}
)
if webhook_event == 'jira:issue_updated':
return self._parse_label_event(raw_payload, webhook_event)
elif webhook_event == 'comment_created':
return self._parse_comment_event(raw_payload, webhook_event)
else:
return JiraPayloadSkipped(f'Unhandled webhook event type: {webhook_event}')
def _parse_label_event(
self, payload: dict, webhook_event: str
) -> JiraPayloadParseResult:
"""Parse an issue_updated event for label changes."""
changelog = payload.get('changelog', {})
items = changelog.get('items', [])
# Extract labels that were added
labels = [
item.get('toString', '')
for item in items
if item.get('field') == 'labels' and 'toString' in item
]
if self.oh_label not in labels:
return JiraPayloadSkipped(
f"Label event does not contain '{self.oh_label}' label"
)
# For label events, user data comes from 'user' field
user_data = payload.get('user', {})
return self._extract_and_validate(
payload=payload,
user_data=user_data,
event_type=JiraEventType.LABELED_TICKET,
webhook_event=webhook_event,
comment_body='',
)
def _parse_comment_event(
self, payload: dict, webhook_event: str
) -> JiraPayloadParseResult:
"""Parse a comment_created event."""
comment_data = payload.get('comment', {})
comment_body = comment_data.get('body', '')
if not self._has_mention(comment_body):
return JiraPayloadSkipped(
f"Comment does not mention '{self.inline_oh_label}'"
)
# For comment events, user data comes from 'comment.author'
user_data = comment_data.get('author', {})
return self._extract_and_validate(
payload=payload,
user_data=user_data,
event_type=JiraEventType.COMMENT_MENTION,
webhook_event=webhook_event,
comment_body=comment_body,
)
def _has_mention(self, text: str) -> bool:
"""Check if text contains an exact mention of OpenHands."""
from integrations.utils import has_exact_mention
return has_exact_mention(text, self.inline_oh_label)
def _extract_and_validate(
self,
payload: dict,
user_data: dict,
event_type: JiraEventType,
webhook_event: str,
comment_body: str,
) -> JiraPayloadParseResult:
"""Extract common fields and validate required data is present."""
issue_data = payload.get('issue', {})
# Extract all fields with empty string defaults (makes them str type)
issue_id = issue_data.get('id', '')
issue_key = issue_data.get('key', '')
user_email = user_data.get('emailAddress', '')
display_name = user_data.get('displayName', '')
account_id = user_data.get('accountId', '')
base_api_url, workspace_name = self._extract_workspace_from_url(
issue_data.get('self', '')
)
# Validate required fields
missing: list[str] = []
if not issue_id:
missing.append('issue.id')
if not issue_key:
missing.append('issue.key')
if not user_email:
missing.append('user.emailAddress')
if not display_name:
missing.append('user.displayName')
if not account_id:
missing.append('user.accountId')
if not workspace_name:
missing.append('workspace_name (derived from issue.self)')
if not base_api_url:
missing.append('base_api_url (derived from issue.self)')
if missing:
return JiraPayloadError(f"Missing required fields: {', '.join(missing)}")
return JiraPayloadSuccess(
JiraWebhookPayload(
event_type=event_type,
raw_event=webhook_event,
issue_id=issue_id,
issue_key=issue_key,
user_email=user_email,
display_name=display_name,
account_id=account_id,
workspace_name=workspace_name,
base_api_url=base_api_url,
comment_body=comment_body,
)
)
def _extract_workspace_from_url(self, self_url: str) -> tuple[str, str]:
"""Extract base API URL and workspace name from issue self URL.
Args:
self_url: The 'self' URL from the issue data
Returns:
Tuple of (base_api_url, workspace_name)
"""
if not self_url:
return '', ''
# Extract base URL (everything before /rest/)
if '/rest/' in self_url:
base_api_url = self_url.split('/rest/')[0]
else:
parsed = urlparse(self_url)
base_api_url = f'{parsed.scheme}://{parsed.netloc}'
# Extract workspace name (hostname)
parsed = urlparse(base_api_url)
workspace_name = parsed.hostname or ''
return base_api_url, workspace_name
+6 -37
View File
@@ -1,42 +1,26 @@
"""Type definitions and interfaces for Jira integration."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from integrations.models import JobContext
from jinja2 import Environment
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from openhands.server.user_auth.user_auth import UserAuth
if TYPE_CHECKING:
from integrations.jira.jira_payload import JiraWebhookPayload
class JiraViewInterface(ABC):
"""Interface for Jira views that handle different types of Jira interactions.
"""Interface for Jira views that handle different types of Jira interactions."""
Views hold the webhook payload directly rather than duplicating fields,
and fetch issue details lazily when needed.
"""
# Core data - view holds these references
payload: 'JiraWebhookPayload'
job_context: JobContext
saas_user_auth: UserAuth
jira_user: JiraUser
jira_workspace: JiraWorkspace
# Mutable state set during processing
selected_repo: str | None
conversation_id: str
@abstractmethod
async def get_issue_details(self) -> tuple[str, str]:
"""Fetch and cache issue title and description from Jira API.
Returns:
Tuple of (issue_title, issue_description)
"""
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Get initial instructions for the conversation."""
pass
@abstractmethod
@@ -51,21 +35,6 @@ class JiraViewInterface(ABC):
class StartingConvoException(Exception):
"""Exception raised when starting a conversation fails.
This provides user-friendly error messages that can be sent back to Jira.
"""
pass
class RepositoryNotFoundError(Exception):
"""Raised when a repository cannot be determined from the issue.
This is a separate error domain from StartingConvoException - it represents
a precondition failure (no repo configured/found) rather than a conversation
creation failure. The manager catches this and converts it to a user-friendly
message.
"""
"""Exception raised when starting a conversation fails."""
pass
+65 -351
View File
@@ -1,21 +1,8 @@
"""Jira view implementations and factory.
from dataclasses import dataclass
Views are responsible for:
- Holding the webhook payload and auth context
- Lazy-loading issue details from Jira API when needed
- Creating conversations with the selected repository
"""
from dataclasses import dataclass, field
import httpx
from integrations.jira.jira_payload import JiraWebhookPayload
from integrations.jira.jira_types import (
JiraViewInterface,
RepositoryNotFoundError,
StartingConvoException,
)
from integrations.utils import CONVERSATION_URL, infer_repo_from_message
from integrations.jira.jira_types import JiraViewInterface, StartingConvoException
from integrations.models import JobContext, Message
from integrations.utils import CONVERSATION_URL, HOST, get_oh_labels, has_exact_mention
from jinja2 import Environment
from storage.jira_conversation import JiraConversation
from storage.jira_integration_store import JiraIntegrationStore
@@ -23,147 +10,52 @@ from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.server.services.conversation_service import create_new_conversation
from openhands.server.services.conversation_service import (
create_new_conversation,
)
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
from openhands.utils.http_session import httpx_verify_option
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
integration_store = JiraIntegrationStore.get_instance()
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
@dataclass
class JiraNewConversationView(JiraViewInterface):
"""View for creating a new Jira conversation.
This view holds the webhook payload directly and lazily fetches
issue details when needed for rendering templates.
"""
payload: JiraWebhookPayload
job_context: JobContext
saas_user_auth: UserAuth
jira_user: JiraUser
jira_workspace: JiraWorkspace
selected_repo: str | None = None
conversation_id: str = ''
selected_repo: str | None
conversation_id: str
# Lazy-loaded issue details (cached after first fetch)
_issue_title: str | None = field(default=None, repr=False)
_issue_description: str | None = field(default=None, repr=False)
# Decrypted API key (set by factory)
_decrypted_api_key: str = field(default='', repr=False)
async def get_issue_details(self) -> tuple[str, str]:
"""Fetch issue details from Jira API (cached after first call).
Returns:
Tuple of (issue_title, issue_description)
Raises:
StartingConvoException: If issue details cannot be fetched
"""
if self._issue_title is not None and self._issue_description is not None:
return self._issue_title, self._issue_description
try:
url = f'{JIRA_CLOUD_API_URL}/{self.jira_workspace.jira_cloud_id}/rest/api/2/issue/{self.payload.issue_key}'
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.get(
url,
auth=(
self.jira_workspace.svc_acc_email,
self._decrypted_api_key,
),
)
response.raise_for_status()
issue_payload = response.json()
if not issue_payload:
raise StartingConvoException(
f'Issue {self.payload.issue_key} not found.'
)
self._issue_title = issue_payload.get('fields', {}).get('summary', '')
self._issue_description = (
issue_payload.get('fields', {}).get('description', '') or ''
)
if not self._issue_title:
raise StartingConvoException(
f'Issue {self.payload.issue_key} does not have a title.'
)
logger.info(
'[Jira] Fetched issue details',
extra={
'issue_key': self.payload.issue_key,
'has_description': bool(self._issue_description),
},
)
return self._issue_title, self._issue_description
except httpx.HTTPStatusError as e:
logger.error(
'[Jira] Failed to fetch issue details',
extra={
'issue_key': self.payload.issue_key,
'status': e.response.status_code,
},
)
raise StartingConvoException(
f'Failed to fetch issue details: HTTP {e.response.status_code}'
)
except Exception as e:
if isinstance(e, StartingConvoException):
raise
logger.error(
'[Jira] Failed to fetch issue details',
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
)
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Get instructions for the conversation.
This fetches issue details if not already cached.
Returns:
Tuple of (system_instructions, user_message)
"""
issue_title, issue_description = await self.get_issue_details()
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
instructions_template = jinja_env.get_template('jira_instructions.j2')
instructions = instructions_template.render()
user_msg_template = jinja_env.get_template('jira_new_conversation.j2')
user_msg = user_msg_template.render(
issue_key=self.payload.issue_key,
issue_title=issue_title,
issue_description=issue_description,
user_message=self.payload.user_msg,
issue_key=self.job_context.issue_key,
issue_title=self.job_context.issue_title,
issue_description=self.job_context.issue_description,
user_message=self.job_context.user_msg or '',
)
return instructions, user_msg
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
"""Create a new Jira conversation.
"""Create a new Jira conversation"""
Returns:
The conversation ID
Raises:
StartingConvoException: If conversation creation fails
"""
if not self.selected_repo:
raise StartingConvoException('No repository selected for this conversation')
provider_tokens = await self.saas_user_auth.get_provider_tokens()
user_secrets = await self.saas_user_auth.get_secrets()
instructions, user_msg = await self._get_instructions(jinja_env)
instructions, user_msg = self._get_instructions(jinja_env)
try:
agent_loop_info = await create_new_conversation(
@@ -181,259 +73,81 @@ class JiraNewConversationView(JiraViewInterface):
self.conversation_id = agent_loop_info.conversation_id
logger.info(
'[Jira] Created conversation',
extra={
'conversation_id': self.conversation_id,
'issue_key': self.payload.issue_key,
'selected_repo': self.selected_repo,
},
)
logger.info(f'[Jira] Created conversation {self.conversation_id}')
# Store Jira conversation mapping
jira_conversation = JiraConversation(
conversation_id=self.conversation_id,
issue_id=self.payload.issue_id,
issue_key=self.payload.issue_key,
issue_id=self.job_context.issue_id,
issue_key=self.job_context.issue_key,
jira_user_id=self.jira_user.id,
)
await integration_store.create_conversation(jira_conversation)
return self.conversation_id
except Exception as e:
if isinstance(e, StartingConvoException):
raise
logger.error(
'[Jira] Failed to create conversation',
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
exc_info=True,
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
)
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
def get_response_msg(self) -> str:
"""Get the response message to send back to Jira."""
"""Get the response message to send back to Jira"""
conversation_link = CONVERSATION_URL.format(self.conversation_id)
return f"I'm on it! {self.payload.display_name} can [track my progress here|{conversation_link}]."
return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
class JiraFactory:
"""Factory for creating Jira views.
The factory is responsible for:
- Creating the appropriate view type
- Inferring and selecting the repository
- Validating all required data is available
Repository selection happens here so that view creation either
succeeds with a valid repo or fails with a clear error.
"""
"""Factory for creating Jira views based on message content"""
@staticmethod
async def _create_provider_handler(user_auth: UserAuth) -> ProviderHandler | None:
"""Create a ProviderHandler for the user."""
provider_tokens = await user_auth.get_provider_tokens()
if provider_tokens is None:
return None
def is_labeled_ticket(message: Message) -> bool:
payload = message.message.get('payload', {})
event_type = payload.get('webhookEvent')
access_token = await user_auth.get_access_token()
user_id = await user_auth.get_user_id()
if event_type != 'jira:issue_updated':
return False
return ProviderHandler(
provider_tokens=provider_tokens,
external_auth_token=access_token,
external_auth_id=user_id,
)
changelog = payload.get('changelog', {})
items = changelog.get('items', [])
labels = [
item.get('toString', '')
for item in items
if item.get('field') == 'labels' and 'toString' in item
]
return OH_LABEL in labels
@staticmethod
def _extract_potential_repos(
issue_key: str,
issue_title: str,
issue_description: str,
user_msg: str,
) -> list[str]:
"""Extract potential repository names from issue content.
def is_ticket_comment(message: Message) -> bool:
payload = message.message.get('payload', {})
event_type = payload.get('webhookEvent')
Raises:
RepositoryNotFoundError: If no potential repos found in text.
"""
search_text = f'{issue_title}\n{issue_description}\n{user_msg}'
potential_repos = infer_repo_from_message(search_text)
if event_type != 'comment_created':
return False
if not potential_repos:
raise RepositoryNotFoundError(
'Could not determine which repository to use. '
'Please mention the repository (e.g., owner/repo) in the issue description or comment.'
)
logger.info(
'[Jira] Found potential repositories in issue content',
extra={'issue_key': issue_key, 'potential_repos': potential_repos},
)
return potential_repos
comment_data = payload.get('comment', {})
comment_body = comment_data.get('body', '')
return has_exact_mention(comment_body, INLINE_OH_LABEL)
@staticmethod
async def _verify_repos(
issue_key: str,
potential_repos: list[str],
provider_handler: ProviderHandler,
) -> list[str]:
"""Verify which repos the user has access to."""
verified_repos: list[str] = []
for repo_name in potential_repos:
try:
repository = await provider_handler.verify_repo_provider(repo_name)
verified_repos.append(repository.full_name)
logger.debug(
'[Jira] Repository verification succeeded',
extra={'issue_key': issue_key, 'repository': repository.full_name},
)
except Exception as e:
logger.debug(
'[Jira] Repository verification failed',
extra={
'issue_key': issue_key,
'repo_name': repo_name,
'error': str(e),
},
)
return verified_repos
@staticmethod
def _select_single_repo(
issue_key: str,
potential_repos: list[str],
verified_repos: list[str],
) -> str:
"""Select exactly one repo from verified repos.
Raises:
RepositoryNotFoundError: If zero or multiple repos verified.
"""
if len(verified_repos) == 0:
raise RepositoryNotFoundError(
f'Could not access any of the mentioned repositories: {", ".join(potential_repos)}. '
'Please ensure you have access to the repository and it exists.'
)
if len(verified_repos) > 1:
raise RepositoryNotFoundError(
f'Multiple repositories found: {", ".join(verified_repos)}. '
'Please specify exactly one repository in the issue description or comment.'
)
logger.info(
'[Jira] Verified repository access',
extra={'issue_key': issue_key, 'repository': verified_repos[0]},
)
return verified_repos[0]
@staticmethod
async def _infer_repository(
payload: JiraWebhookPayload,
user_auth: UserAuth,
issue_title: str,
issue_description: str,
) -> str:
"""Infer and verify the repository from issue content.
Raises:
RepositoryNotFoundError: If no valid repository can be determined.
"""
provider_handler = await JiraFactory._create_provider_handler(user_auth)
if not provider_handler:
raise RepositoryNotFoundError(
'No Git provider connected. Please connect a Git provider in OpenHands settings.'
)
potential_repos = JiraFactory._extract_potential_repos(
payload.issue_key, issue_title, issue_description, payload.user_msg
)
verified_repos = await JiraFactory._verify_repos(
payload.issue_key, potential_repos, provider_handler
)
return JiraFactory._select_single_repo(
payload.issue_key, potential_repos, verified_repos
)
@staticmethod
async def create_view(
payload: JiraWebhookPayload,
workspace: JiraWorkspace,
user: JiraUser,
user_auth: UserAuth,
decrypted_api_key: str,
async def create_jira_view_from_payload(
job_context: JobContext,
saas_user_auth: UserAuth,
jira_user: JiraUser,
jira_workspace: JiraWorkspace,
) -> JiraViewInterface:
"""Create a Jira view with repository already selected.
"""Create appropriate Jira view based on the message and user state"""
This factory method:
1. Creates the view with payload and auth context
2. Fetches issue details (needed for repo inference)
3. Infers and selects the repository
if not jira_user or not saas_user_auth or not jira_workspace:
raise StartingConvoException('User not authenticated with Jira integration')
If any step fails, an appropriate exception is raised with
a user-friendly message.
Args:
payload: Parsed webhook payload
workspace: The Jira workspace
user: The Jira user
user_auth: OpenHands user authentication
decrypted_api_key: Decrypted service account API key
Returns:
A JiraViewInterface with selected_repo populated
Raises:
StartingConvoException: If view creation fails
RepositoryNotFoundError: If repository cannot be determined
"""
logger.info(
'[Jira] Creating view',
extra={
'issue_key': payload.issue_key,
'event_type': payload.event_type.value,
},
return JiraNewConversationView(
job_context=job_context,
saas_user_auth=saas_user_auth,
jira_user=jira_user,
jira_workspace=jira_workspace,
selected_repo=None, # Will be set later after repo inference
conversation_id='', # Will be set when conversation is created
)
# Create the view
view = JiraNewConversationView(
payload=payload,
saas_user_auth=user_auth,
jira_user=user,
jira_workspace=workspace,
_decrypted_api_key=decrypted_api_key,
)
# Fetch issue details (needed for repo inference)
try:
issue_title, issue_description = await view.get_issue_details()
except StartingConvoException:
raise # Re-raise with original message
except Exception as e:
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
# Infer and select repository
selected_repo = await JiraFactory._infer_repository(
payload=payload,
user_auth=user_auth,
issue_title=issue_title,
issue_description=issue_description,
)
view.selected_repo = selected_repo
logger.info(
'[Jira] View created successfully',
extra={
'issue_key': payload.issue_key,
'selected_repo': selected_repo,
},
)
return view
+5
View File
@@ -16,6 +16,11 @@ class Manager(ABC):
"Send message to integration from Openhands server"
raise NotImplementedError
@abstractmethod
async def is_job_requested(self, message: Message) -> bool:
"Confirm that a job is being requested"
raise NotImplementedError
@abstractmethod
def start_job(self):
"Kick off a job with openhands agent"
+31 -20
View File
@@ -398,42 +398,53 @@ def infer_repo_from_message(user_msg: str) -> list[str]:
"""
Extract all repository names in the format 'owner/repo' from various Git provider URLs
and direct mentions in text. Supports GitHub, GitLab, and BitBucket.
Args:
user_msg: Input message that may contain repository references
Returns:
List of repository names in 'owner/repo' format, empty list if none found
"""
# Normalize the message by removing extra whitespace and newlines
normalized_msg = re.sub(r'\s+', ' ', user_msg.strip())
git_url_pattern = (
r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/'
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?'
r'(?:[/?#].*?)?(?=\s|$|[^\w.-])'
)
# Pattern to match Git URLs from GitHub, GitLab, and BitBucket
# Captures: protocol, domain, owner, repo (with optional .git extension)
git_url_pattern = r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?(?:[/?#].*?)?(?=\s|$|[^\w.-])'
# UPDATED: allow {{ owner/repo }} in addition to existing boundaries
# Pattern to match direct owner/repo mentions (e.g., "OpenHands/OpenHands")
# Must be surrounded by word boundaries or specific characters to avoid false positives
direct_pattern = (
r'(?:^|\s|{{|[\[\(\'":`])' # left boundary
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)'
r'(?=\s|$|}}|[\]\)\'",.:`])' # right boundary
r'(?:^|\s|[\[\(\'"])([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)(?=\s|$|[\]\)\'",.])'
)
matches: list[str] = []
matches = []
# Git URLs first (highest priority)
for owner, repo in re.findall(git_url_pattern, normalized_msg):
# First, find all Git URLs (highest priority)
git_matches = re.findall(git_url_pattern, normalized_msg)
for owner, repo in git_matches:
# Remove .git extension if present
repo = re.sub(r'\.git$', '', repo)
matches.append(f'{owner}/{repo}')
# Direct mentions
for owner, repo in re.findall(direct_pattern, normalized_msg):
# Second, find all direct owner/repo mentions
direct_matches = re.findall(direct_pattern, normalized_msg)
for owner, repo in direct_matches:
full_match = f'{owner}/{repo}'
# Skip if it looks like a version number, date, or file path
if (
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match)
or re.match(r'^\d{1,2}/\d{1,2}$', full_match)
or re.match(r'^[A-Z]/[A-Z]$', full_match)
or repo.endswith(('.txt', '.md', '.py', '.js'))
or ('.' in repo and len(repo.split('.')) > 2)
):
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match) # version numbers
or re.match(r'^\d{1,2}/\d{1,2}$', full_match) # dates
or re.match(r'^[A-Z]/[A-Z]$', full_match) # single letters
or repo.endswith('.txt')
or repo.endswith('.md') # file extensions
or repo.endswith('.py')
or repo.endswith('.js')
or '.' in repo
and len(repo.split('.')) > 2
): # complex file paths
continue
# Avoid duplicates from Git URLs already found
if full_match not in matches:
matches.append(full_match)
@@ -1,28 +0,0 @@
"""Add git_user_name and git_user_email columns to user table.
Revision ID: 090
Revises: 089
Create Date: 2025-01-22
"""
import sqlalchemy as sa
from alembic import op
revision = '090'
down_revision = '089'
def upgrade() -> None:
op.add_column(
'user',
sa.Column('git_user_name', sa.String, nullable=True),
)
op.add_column(
'user',
sa.Column('git_user_email', sa.String, nullable=True),
)
def downgrade() -> None:
op.drop_column('user', 'git_user_email')
op.drop_column('user', 'git_user_name')
+12 -12
View File
@@ -6102,14 +6102,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
[[package]]
name = "openhands-agent-server"
version = "1.10.0"
version = "1.8.2"
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
{file = "openhands_agent_server-1.8.2-py3-none-any.whl", hash = "sha256:e9abb2e0fe970715537d0e0fc1aea3dd64bb9e8b531f70cb72b3d4e486aaa46a"},
{file = "openhands_agent_server-1.8.2.tar.gz", hash = "sha256:43db2371ee84b100ac921396338dee74359fceeb5c9400c90530bcc5730144c3"},
]
[package.dependencies]
@@ -6168,9 +6168,9 @@ memory-profiler = ">=0.61"
numpy = "*"
openai = "2.8"
openhands-aci = "0.3.2"
openhands-agent-server = "1.10"
openhands-sdk = "1.10"
openhands-tools = "1.10"
openhands-agent-server = "1.8.2"
openhands-sdk = "1.8.2"
openhands-tools = "1.8.2"
opentelemetry-api = ">=1.33.1"
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
pathspec = ">=0.12.1"
@@ -6225,14 +6225,14 @@ url = ".."
[[package]]
name = "openhands-sdk"
version = "1.10.0"
version = "1.8.2"
description = "OpenHands SDK - Core functionality for building AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
{file = "openhands_sdk-1.8.2-py3-none-any.whl", hash = "sha256:b4fad9581865ce222a3e6722384e4df56113db01bd34c2d2d408dfd9695365c0"},
{file = "openhands_sdk-1.8.2.tar.gz", hash = "sha256:5bfb17c8b9515210d121249deb1f3d0dc407c3737edc55b5e73330b4571d61e3"},
]
[package.dependencies]
@@ -6253,14 +6253,14 @@ boto3 = ["boto3 (>=1.35.0)"]
[[package]]
name = "openhands-tools"
version = "1.10.0"
version = "1.8.2"
description = "OpenHands Tools - Runtime tools for AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
{file = "openhands_tools-1.8.2-py3-none-any.whl", hash = "sha256:283f0c1fdd316914559cd16ade792383715478a8f5a73f7166daffc34bf9e5af"},
{file = "openhands_tools-1.8.2.tar.gz", hash = "sha256:eae416e3867f7cb595129a33a4b9237886c4b8a075d2bc7618da55963f2747d5"},
]
[package.dependencies]
-3
View File
@@ -16,7 +16,6 @@ from keycloak.exceptions import (
KeycloakError,
KeycloakPostError,
)
from server.auth.auth_error import ExpiredError
from server.auth.constants import (
BITBUCKET_APP_CLIENT_ID,
BITBUCKET_APP_CLIENT_SECRET,
@@ -427,8 +426,6 @@ class TokenManager:
access_token = data.get('access_token')
refresh_token = data.get('refresh_token')
if not access_token or not refresh_token:
if data.get('error') == 'bad_refresh_token':
raise ExpiredError()
raise ValueError(
'Failed to refresh token: missing access_token or refresh_token in response.'
)
@@ -14,6 +14,7 @@ from storage.conversation_callback import (
ConversationCallback,
ConversationCallbackProcessor,
)
from storage.database import session_maker
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
@@ -107,10 +108,13 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
f'[GitHub] Sent summary instruction to conversation {conversation_id} {summary_event}'
)
# Update the processor state - the outer session will commit this
# Update the processor state
self.send_summary_instruction = False
callback.set_processor(self)
callback.updated_at = datetime.now()
with session_maker() as session:
session.merge(callback)
session.commit()
return
# Extract the summary from the event store
@@ -126,15 +130,14 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
logger.info(f'[GitHub] Summary sent for conversation {conversation_id}')
# Mark callback as completed status - the outer session will commit this
# Mark callback as completed status
callback.status = CallbackStatus.COMPLETED
callback.updated_at = datetime.now()
with session_maker() as session:
session.merge(callback)
session.commit()
except Exception as e:
logger.exception(
f'[GitHub] Error processing conversation callback: {str(e)}'
)
# Mark callback as error to prevent infinite re-invocation
# The outer session will commit this
callback.status = CallbackStatus.ERROR
callback.updated_at = datetime.now()
+1 -1
View File
@@ -144,7 +144,7 @@ class SetAuthCookieMiddleware:
# "if accepted_tos is not None" as there should not be any users with
# accepted_tos equal to "None"
if accepted_tos is False and request.url.path != '/api/accept_tos':
logger.warning('User has not accepted the terms of service')
logger.error('User has not accepted the terms of service')
raise TosNotAcceptedError
def _should_attach(self, request: Request) -> bool:
+33 -29
View File
@@ -13,33 +13,46 @@ from server.constants import (
STRIPE_API_KEY,
)
from server.logger import logger
from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.subscription_access import SubscriptionAccess
from storage.user_store import UserStore
from openhands.app_server.config import get_global_config
from openhands.server.user_auth import get_user_id
stripe.api_key = STRIPE_API_KEY
billing_router = APIRouter(prefix='/api/billing')
async def validate_billing_enabled() -> None:
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
def is_all_hands_saas_environment(request: Request) -> bool:
"""Check if the current domain is an All Hands SaaS environment.
Args:
request: FastAPI Request object
Returns:
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
"""
Validate that the billing feature flag is enabled
hostname = request.url.hostname or ''
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
def validate_saas_environment(request: Request) -> None:
"""Validate that the request is coming from an All Hands SaaS environment.
Args:
request: FastAPI Request object
Raises:
HTTPException: If the request is not from an All Hands SaaS environment
"""
config = get_global_config()
web_client_config = await config.web_client.get_web_client_config()
if not web_client_config.feature_flags.enable_billing:
if not is_all_hands_saas_environment(request):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=(
'Billing is disabled in this environment. '
'Please set OH_WEB_CLIENT_FEATURE_FLAGS_ENABLE_BILLING to enable billing.'
),
detail='Checkout sessions are only available for All Hands SaaS environments',
)
@@ -141,15 +154,14 @@ async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
async def create_customer_setup_session(
request: Request, user_id: str = Depends(get_user_id)
) -> CreateBillingSessionResponse:
await validate_billing_enabled()
validate_saas_environment(request)
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
base_url = _get_base_url(request)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_info['customer_id'],
mode='setup',
payment_method_types=['card'],
success_url=f'{base_url}?free_credits=success',
cancel_url=f'{base_url}',
success_url=f'{request.base_url}?free_credits=success',
cancel_url=f'{request.base_url}',
)
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
@@ -161,8 +173,8 @@ async def create_checkout_session(
request: Request,
user_id: str = Depends(get_user_id),
) -> CreateBillingSessionResponse:
await validate_billing_enabled()
base_url = _get_base_url(request)
validate_saas_environment(request)
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_info['customer_id'],
@@ -185,8 +197,8 @@ async def create_checkout_session(
saved_payment_method_options={
'payment_method_save': 'enabled',
},
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
)
logger.info(
'created_stripe_checkout_session',
@@ -277,7 +289,7 @@ async def success_callback(session_id: str, request: Request):
session.commit()
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
f'{request.base_url}settings/billing?checkout=success', status_code=302
)
@@ -305,13 +317,5 @@ async def cancel_callback(session_id: str, request: Request):
session.commit()
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
)
def _get_base_url(request: Request) -> URL:
# Never send any part of the credit card process over a non secure connection
base_url = request.base_url
if base_url.hostname != 'localhost':
base_url = base_url.replace(scheme='https')
return base_url
+2 -9
View File
@@ -1,7 +1,6 @@
import hashlib
import json
import os
import zlib
from base64 import b64decode, b64encode
from urllib.parse import parse_qs, urlencode, urlparse
@@ -52,11 +51,7 @@ def add_github_proxy_routes(app: FastAPI):
state_payload = json.dumps(
[query_params['state'][0], query_params['redirect_uri'][0]]
)
# Compress before encrypting to reduce URL length
# This is critical for feature deployments where reCAPTCHA tokens in state
# can cause "URL too long" errors from GitHub
compressed_payload = zlib.compress(state_payload.encode())
state = b64encode(_fernet().encrypt(compressed_payload)).decode()
state = b64encode(_fernet().encrypt(state_payload.encode())).decode()
query_params['state'] = [state]
query_params['redirect_uri'] = [
f'https://{request.url.netloc}/github-proxy/callback'
@@ -72,9 +67,7 @@ def add_github_proxy_routes(app: FastAPI):
parsed_url = urlparse(str(request.url))
query_params = parse_qs(parsed_url.query)
state = query_params['state'][0]
# Decrypt and decompress (reverse of github_proxy_start)
decrypted_payload = _fernet().decrypt(b64decode(state.encode()))
decrypted_state = zlib.decompress(decrypted_payload).decode()
decrypted_state = _fernet().decrypt(b64decode(state.encode())).decode()
# Build query Params
state, redirect_uri = json.loads(decrypted_state)
-104
View File
@@ -1,5 +1,4 @@
from pydantic import BaseModel, EmailStr, Field
from storage.org import Org
class OrgCreationError(Exception):
@@ -28,27 +27,6 @@ class OrgDatabaseError(OrgCreationError):
pass
class OrgDeletionError(Exception):
"""Base exception for organization deletion errors."""
pass
class OrgAuthorizationError(OrgDeletionError):
"""Raised when user is not authorized to delete organization."""
def __init__(self, message: str = 'Not authorized to delete organization'):
super().__init__(message)
class OrgNotFoundError(Exception):
"""Raised when organization is not found or user doesn't have access."""
def __init__(self, org_id: str):
self.org_id = org_id
super().__init__(f'Organization with id "{org_id}" not found')
class OrgCreate(BaseModel):
"""Request model for creating a new organization."""
@@ -87,85 +65,3 @@ class OrgResponse(BaseModel):
enable_solvability_analysis: bool | None = None
v1_enabled: bool | None = None
credits: float | None = None
@classmethod
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
"""Create an OrgResponse from an Org entity.
Args:
org: The organization entity to convert
credits: Optional credits value (defaults to None)
Returns:
OrgResponse: The response model instance
"""
return cls(
id=str(org.id),
name=org.name,
contact_name=org.contact_name,
contact_email=org.contact_email,
conversation_expiration=org.conversation_expiration,
agent=org.agent,
default_max_iterations=org.default_max_iterations,
security_analyzer=org.security_analyzer,
confirmation_mode=org.confirmation_mode,
default_llm_model=org.default_llm_model,
default_llm_api_key_for_byor=None,
default_llm_base_url=org.default_llm_base_url,
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
enable_default_condenser=org.enable_default_condenser
if org.enable_default_condenser is not None
else True,
billing_margin=org.billing_margin,
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
if org.enable_proactive_conversation_starters is not None
else True,
sandbox_base_container_image=org.sandbox_base_container_image,
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
org_version=org.org_version if org.org_version is not None else 0,
mcp_config=org.mcp_config,
search_api_key=None,
sandbox_api_key=None,
max_budget_per_task=org.max_budget_per_task,
enable_solvability_analysis=org.enable_solvability_analysis,
v1_enabled=org.v1_enabled,
credits=credits,
)
class OrgPage(BaseModel):
"""Paginated response model for organization list."""
items: list[OrgResponse]
next_page_id: str | None = None
class OrgUpdate(BaseModel):
"""Request model for updating an organization."""
# Basic organization information (any authenticated user can update)
contact_name: str | None = None
contact_email: EmailStr | None = Field(default=None, strip_whitespace=True)
conversation_expiration: int | None = None
default_max_iterations: int | None = Field(default=None, gt=0)
remote_runtime_resource_factor: int | None = Field(default=None, gt=0)
billing_margin: float | None = Field(default=None, ge=0, le=1)
enable_proactive_conversation_starters: bool | None = None
sandbox_base_container_image: str | None = None
sandbox_runtime_container_image: str | None = None
mcp_config: dict | None = None
sandbox_api_key: str | None = None
max_budget_per_task: float | None = Field(default=None, gt=0)
enable_solvability_analysis: bool | None = None
v1_enabled: bool | None = None
# LLM settings (require admin/owner role)
default_llm_model: str | None = None
default_llm_api_key_for_byor: str | None = None
default_llm_base_url: str | None = None
search_api_key: str | None = None
security_analyzer: str | None = None
agent: str | None = None
confirmation_mode: bool | None = None
enable_default_condenser: bool | None = None
condenser_max_size: int | None = Field(default=None, ge=20)
+26 -311
View File
@@ -1,98 +1,20 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, HTTPException, status
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
LiteLLMIntegrationError,
OrgAuthorizationError,
OrgCreate,
OrgDatabaseError,
OrgNameExistsError,
OrgNotFoundError,
OrgPage,
OrgResponse,
OrgUpdate,
)
from storage.org_service import OrgService
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
# Initialize API router
org_router = APIRouter(prefix='/api/organizations')
@org_router.get('', response_model=OrgPage)
async def list_user_orgs(
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
] = 100,
user_id: str = Depends(get_user_id),
) -> OrgPage:
"""List organizations for the authenticated user.
This endpoint returns a paginated list of all organizations that the
authenticated user is a member of.
Args:
page_id: Optional page ID (offset) for pagination
limit: Maximum number of organizations to return (1-100, default 100)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgPage: Paginated list of organizations
Raises:
HTTPException: 500 if retrieval fails
"""
logger.info(
'Listing organizations for user',
extra={
'user_id': user_id,
'page_id': page_id,
'limit': limit,
},
)
try:
# Fetch organizations from service layer
orgs, next_page_id = OrgService.get_user_orgs_paginated(
user_id=user_id,
page_id=page_id,
limit=limit,
)
# Convert Org entities to OrgResponse objects
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
logger.info(
'Successfully retrieved organizations',
extra={
'user_id': user_id,
'org_count': len(org_responses),
'has_more': next_page_id is not None,
},
)
return OrgPage(items=org_responses, next_page_id=next_page_id)
except Exception as e:
logger.exception(
'Unexpected error listing organizations',
extra={'user_id': user_id, 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve organizations',
)
@org_router.post('', response_model=OrgResponse, status_code=status.HTTP_201_CREATED)
async def create_org(
org_data: OrgCreate,
@@ -136,7 +58,31 @@ async def create_org(
# Retrieve credits from LiteLLM
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse.from_org(org, credits=credits)
return OrgResponse(
id=str(org.id),
name=org.name,
contact_name=org.contact_name,
contact_email=org.contact_email,
conversation_expiration=org.conversation_expiration,
agent=org.agent,
default_max_iterations=org.default_max_iterations,
security_analyzer=org.security_analyzer,
confirmation_mode=org.confirmation_mode,
default_llm_model=org.default_llm_model,
default_llm_base_url=org.default_llm_base_url,
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
enable_default_condenser=org.enable_default_condenser,
billing_margin=org.billing_margin,
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
sandbox_base_container_image=org.sandbox_base_container_image,
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
org_version=org.org_version,
mcp_config=org.mcp_config,
max_budget_per_task=org.max_budget_per_task,
enable_solvability_analysis=org.enable_solvability_analysis,
v1_enabled=org.v1_enabled,
credits=credits,
)
except OrgNameExistsError as e:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -169,234 +115,3 @@ async def create_org(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
async def get_org(
org_id: UUID,
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Get organization details by ID.
This endpoint allows authenticated users who are members of an organization
to retrieve its details. Only members of the organization can access this endpoint.
Args:
org_id: Organization ID (UUID)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The organization details
Raises:
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
HTTPException: 404 if organization not found or user is not a member
HTTPException: 500 if retrieval fails
"""
logger.info(
'Retrieving organization details',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to get organization with membership validation
org = await OrgService.get_org_by_id(
org_id=org_id,
user_id=user_id,
)
# Retrieve credits from LiteLLM
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse.from_org(org, credits=credits)
except OrgNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except Exception as e:
logger.exception(
'Unexpected error retrieving organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
async def delete_org(
org_id: UUID,
user_id: str = Depends(get_admin_user_id),
) -> dict:
"""Delete an organization.
This endpoint allows authenticated organization owners to delete their organization.
All associated data including organization members, conversations, billing data,
and external LiteLLM team resources will be permanently removed.
Args:
org_id: Organization ID to delete
user_id: Authenticated user ID (injected by dependency)
Returns:
dict: Confirmation message with deleted organization details
Raises:
HTTPException: 403 if user is not the organization owner
HTTPException: 404 if organization not found
HTTPException: 500 if deletion fails
"""
logger.info(
'Organization deletion requested',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to delete organization with cleanup
deleted_org = await OrgService.delete_org_with_cleanup(
user_id=user_id,
org_id=org_id,
)
logger.info(
'Organization deletion completed successfully',
extra={
'user_id': user_id,
'org_id': str(org_id),
'org_name': deleted_org.name,
},
)
return {
'message': 'Organization deleted successfully',
'organization': {
'id': str(deleted_org.id),
'name': deleted_org.name,
'contact_name': deleted_org.contact_name,
'contact_email': deleted_org.contact_email,
},
}
except OrgNotFoundError as e:
logger.warning(
'Organization not found for deletion',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except OrgAuthorizationError as e:
logger.warning(
'User not authorized to delete organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database error during organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to delete organization',
)
except Exception as e:
logger.exception(
'Unexpected error during organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.patch('/{org_id}', response_model=OrgResponse)
async def update_org(
org_id: UUID,
update_data: OrgUpdate,
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Update an existing organization.
This endpoint allows authenticated users to update organization settings.
LLM-related settings require admin or owner role in the organization.
Args:
org_id: Organization ID to update (UUID validated by FastAPI)
update_data: Organization update data
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The updated organization details
Raises:
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
HTTPException: 403 if user lacks permission for LLM settings
HTTPException: 404 if organization not found
HTTPException: 422 if validation errors occur (handled by FastAPI)
HTTPException: 500 if update fails
"""
logger.info(
'Updating organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to update organization with permission checks
updated_org = await OrgService.update_org_with_permissions(
org_id=org_id,
update_data=update_data,
user_id=user_id,
)
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
credits = await OrgService.get_org_credits(user_id, updated_org.id)
return OrgResponse.from_org(updated_org, credits=credits)
except ValueError as e:
# Organization not found
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except PermissionError as e:
# User lacks permission for LLM settings
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database operation failed',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to update organization',
)
except Exception as e:
logger.exception(
'Unexpected error updating organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@@ -26,7 +26,6 @@ from server.sharing.shared_conversation_models import (
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
@@ -58,7 +57,7 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
query = self._public_select_with_saas_metadata()
query = self._public_select()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
@@ -105,17 +104,14 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.all()
rows = result.scalars().all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [
self._to_shared_conversation(stored, saas_metadata=saas_metadata)
for stored, saas_metadata in rows
]
items = [self._to_shared_conversation(row) for row in rows]
# Calculate next page ID
next_page_id = None
@@ -156,18 +152,17 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single public conversation info, returning None if missing or not shared."""
query = self._public_select_with_saas_metadata().where(
query = self._public_select().where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result = await self.db_session.execute(query)
row = result.first()
stored = result.scalar_one_or_none()
if row is None:
if stored is None:
return None
stored, saas_metadata = row
return self._to_shared_conversation(stored, saas_metadata=saas_metadata)
return self._to_shared_conversation(stored)
def _public_select(self):
"""Create a select query that only returns public conversations."""
@@ -178,25 +173,6 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
return query
def _public_select_with_saas_metadata(self):
"""Create a select query that returns public conversations with SAAS metadata.
This joins with conversation_metadata_saas to retrieve the user_id needed
for constructing the correct event storage path. Uses LEFT OUTER JOIN to
support conversations that may not have SAAS metadata (e.g., in tests).
"""
query = (
select(StoredConversationMetadata, StoredConversationMetadataSaas)
.outerjoin(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.where(StoredConversationMetadata.conversation_version == 'V1')
.where(StoredConversationMetadata.public == True) # noqa: E712
)
return query
def _apply_filters(
self,
query,
@@ -235,16 +211,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
def _to_shared_conversation(
self,
stored: StoredConversationMetadata,
saas_metadata: StoredConversationMetadataSaas | None = None,
sub_conversation_ids: list[UUID] | None = None,
) -> SharedConversation:
"""Convert StoredConversationMetadata to SharedConversation.
Args:
stored: The base conversation metadata from conversation_metadata table.
saas_metadata: Optional SAAS metadata containing user_id and org_id.
sub_conversation_ids: Optional list of sub-conversation IDs.
"""
"""Convert StoredConversationMetadata to SharedConversation."""
# V1 conversations should always have a sandbox_id
sandbox_id = stored.sandbox_id
assert sandbox_id is not None
@@ -270,16 +239,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
created_at = self._fix_timezone(stored.created_at)
updated_at = self._fix_timezone(stored.last_updated_at)
# Get user_id from SAAS metadata if available
created_by_user_id = (
str(saas_metadata.user_id)
if saas_metadata and saas_metadata.user_id
else None
)
return SharedConversation(
id=UUID(stored.conversation_id),
created_by_user_id=created_by_user_id,
created_by_user_id=None, # user_id is no longer stored in conversation metadata
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,
-23
View File
@@ -98,29 +98,6 @@ def decrypt_legacy_value(value: str | SecretStr) -> str:
return get_fernet().decrypt(b64decode(value.encode())).decode()
def encrypt_legacy_model(encrypt_keys: list, model_instance) -> dict:
return encrypt_legacy_kwargs(encrypt_keys, model_to_kwargs(model_instance))
def encrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
for key, value in kwargs.items():
if value is None:
continue
if key in encrypt_keys:
value = encrypt_legacy_value(value)
kwargs[key] = value
return kwargs
def encrypt_legacy_value(value: str | SecretStr) -> str:
if isinstance(value, SecretStr):
return b64encode(
get_fernet().encrypt(value.get_secret_value().encode())
).decode()
else:
return b64encode(get_fernet().encrypt(value.encode())).decode()
def get_fernet():
global _fernet
if _fernet is None:
+1 -222
View File
@@ -96,7 +96,7 @@ class LiteLlmManager:
user_settings: UserSettings,
) -> UserSettings | None:
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:start',
'SettingsStore:umigrate_lite_llm_entries:start',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
@@ -141,35 +141,19 @@ class LiteLlmManager:
return None
credits = max(max_budget - spend, 0.0)
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:create_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, credits
)
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:update_user',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_user(
client, keycloak_user_id, max_budget=UNLIMITED_BUDGET_SETTING
)
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:add_user_to_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, credits
)
if user_settings.llm_api_key:
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:update_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
@@ -178,10 +162,6 @@ class LiteLlmManager:
)
if user_settings.llm_api_key_for_byor:
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:update_byor_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
@@ -189,167 +169,6 @@ class LiteLlmManager:
team_id=org_id,
)
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:complete',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
return user_settings
@staticmethod
async def downgrade_entries(
org_id: str,
keycloak_user_id: str,
user_settings: UserSettings,
) -> UserSettings | None:
"""Downgrade a migrated user's LiteLLM entries back to the pre-migration state.
This reverses the migrate_entries operation:
1. Get the user max budget from their org team in litellm
2. Set the max budget in the user in litellm (restore from team)
3. Add the user back to the default team in litellm
4. Update keys to remove org team association
5. Remove the user from their org team in litellm
6. Delete the user org team in litellm
Note: The database changes (already_migrated flag, org/org_member deletion)
should be handled separately by the caller.
Args:
org_id: The organization ID (which is also the team_id in litellm)
keycloak_user_id: The user's Keycloak ID
user_settings: The user's settings object
Returns:
The user_settings if downgrade was successful, None otherwise
"""
logger.info(
'LiteLlmManager:downgrade_entries:start',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
if not local_deploy:
async with httpx.AsyncClient(
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
# Step 1: Get the team info to retrieve the budget
logger.debug(
'LiteLlmManager:downgrade_entries:get_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
team_info = await LiteLlmManager._get_team(client, org_id)
if not team_info:
logger.error(
'LiteLlmManager:downgrade_entries:team_not_found',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
return None
# Get team budget (max_budget) and spend to calculate current credits
team_data = team_info.get('team_info', {})
max_budget = team_data.get('max_budget', 0.0)
spend = team_data.get('spend', 0.0)
# Get user membership info for budget in team
user_membership = await LiteLlmManager._get_user_team_info(
client, keycloak_user_id, org_id
)
if user_membership:
# Use user's budget in team if available
user_max_budget_in_team = user_membership.get('max_budget_in_team')
user_spend_in_team = user_membership.get('spend', 0.0)
if user_max_budget_in_team is not None:
max_budget = user_max_budget_in_team
spend = user_spend_in_team
# Calculate total budget to restore (credits + spend = max_budget)
# We restore the full max_budget that was on the team/user-in-team
restored_budget = max_budget if max_budget else 0.0
logger.debug(
'LiteLlmManager:downgrade_entries:budget_info',
extra={
'org_id': org_id,
'user_id': keycloak_user_id,
'max_budget': max_budget,
'spend': spend,
'restored_budget': restored_budget,
},
)
# Step 2: Update user to set their max_budget back from unlimited
logger.debug(
'LiteLlmManager:downgrade_entries:update_user',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_user(
client, keycloak_user_id, max_budget=restored_budget, spend=spend
)
# Step 3: Add user back to the default team
if LITE_LLM_TEAM_ID:
logger.debug(
'LiteLlmManager:downgrade_entries:add_to_default_team',
extra={
'org_id': org_id,
'user_id': keycloak_user_id,
'default_team_id': LITE_LLM_TEAM_ID,
},
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, LITE_LLM_TEAM_ID, restored_budget
)
# Step 4: Update keys to remove org team association (set team_id to default)
if user_settings.llm_api_key:
logger.debug(
'LiteLlmManager:downgrade_entries:update_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
user_settings.llm_api_key,
team_id=LITE_LLM_TEAM_ID,
)
if user_settings.llm_api_key_for_byor:
logger.debug(
'LiteLlmManager:downgrade_entries:update_byor_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
user_settings.llm_api_key_for_byor,
team_id=LITE_LLM_TEAM_ID,
)
# Step 5: Remove user from their org team
logger.debug(
'LiteLlmManager:downgrade_entries:remove_from_org_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._remove_user_from_team(
client, keycloak_user_id, org_id
)
# Step 6: Delete the org team
logger.debug(
'LiteLlmManager:downgrade_entries:delete_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._delete_team(client, org_id)
logger.info(
'LiteLlmManager:downgrade_entries:complete',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
return user_settings
@staticmethod
@@ -794,45 +613,6 @@ class LiteLlmManager:
)
response.raise_for_status()
@staticmethod
async def _remove_user_from_team(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/team/member_delete',
json={
'team_id': team_id,
'user_id': keycloak_user_id,
},
)
if not response.is_success:
if response.status_code == 404:
# User not in team, that's fine for downgrade
logger.info(
'User not in team during removal',
extra={'user_id': keycloak_user_id, 'team_id': team_id},
)
return
logger.error(
'error_removing_litellm_user_from_team',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': keycloak_user_id,
'team_id': team_id,
},
)
response.raise_for_status()
logger.info(
'LiteLlmManager:_remove_user_from_team:user_removed',
extra={'user_id': keycloak_user_id, 'team_id': team_id},
)
@staticmethod
async def _generate_key(
client: httpx.AsyncClient,
@@ -1076,7 +856,6 @@ class LiteLlmManager:
delete_user = staticmethod(with_http_client(_delete_user))
delete_team = staticmethod(with_http_client(_delete_team))
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
remove_user_from_team = staticmethod(with_http_client(_remove_user_from_team))
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
generate_key = staticmethod(with_http_client(_generate_key))
-401
View File
@@ -9,11 +9,8 @@ from uuid import UUID as parse_uuid
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
from server.routes.org_models import (
LiteLLMIntegrationError,
OrgAuthorizationError,
OrgDatabaseError,
OrgNameExistsError,
OrgNotFoundError,
OrgUpdate,
)
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
@@ -396,224 +393,6 @@ class OrgService:
)
return e
@staticmethod
def has_admin_or_owner_role(user_id: str, org_id: UUID) -> bool:
"""
Check if user has admin or owner role in the specified organization.
Args:
user_id: User ID to check
org_id: Organization ID to check membership in
Returns:
bool: True if user has admin or owner role, False otherwise
"""
try:
# Parse user_id as UUID for database query
user_uuid = parse_uuid(user_id)
# Get the user's membership in this organization
# Note: The type annotation says int but the actual column is UUID
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
if not org_member:
return False
# Get the role details
role = RoleStore.get_role_by_id(org_member.role_id)
if not role:
return False
# Admin and owner roles have elevated permissions
# Based on test files, both admin and owner have rank 1
return role.name in ['admin', 'owner']
except Exception as e:
logger.warning(
'Error checking user role in organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
'error': str(e),
},
)
return False
@staticmethod
def is_org_member(user_id: str, org_id: UUID) -> bool:
"""
Check if user is a member of the specified organization.
Args:
user_id: User ID to check
org_id: Organization ID to check membership in
Returns:
bool: True if user is a member, False otherwise
"""
try:
user_uuid = parse_uuid(user_id)
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
return org_member is not None
except Exception as e:
logger.warning(
'Error checking user membership in organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
'error': str(e),
},
)
return False
@staticmethod
def _get_llm_settings_fields() -> set[str]:
"""
Get the set of organization fields that are considered LLM settings
and require admin/owner role to update.
Returns:
set[str]: Set of field names that require elevated permissions
"""
return {
'default_llm_model',
'default_llm_api_key_for_byor',
'default_llm_base_url',
'search_api_key',
'security_analyzer',
'agent',
'confirmation_mode',
'enable_default_condenser',
'condenser_max_size',
}
@staticmethod
def _has_llm_settings_updates(update_data: OrgUpdate) -> set[str]:
"""
Check if the update contains any LLM settings fields.
Args:
update_data: The organization update data
Returns:
set[str]: Set of LLM fields being updated (empty if none)
"""
llm_fields = OrgService._get_llm_settings_fields()
update_dict = update_data.model_dump(exclude_none=True)
return llm_fields.intersection(update_dict.keys())
@staticmethod
async def update_org_with_permissions(
org_id: UUID,
update_data: OrgUpdate,
user_id: str,
) -> Org:
"""
Update organization with permission checks for LLM settings.
Args:
org_id: Organization UUID to update
update_data: Organization update data from request
user_id: ID of the user requesting the update
Returns:
Org: The updated organization object
Raises:
ValueError: If organization not found
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
OrgDatabaseError: If database update fails
"""
logger.info(
'Updating organization with permission checks',
extra={
'org_id': str(org_id),
'user_id': user_id,
'has_update_data': update_data is not None,
},
)
# Validate organization exists
existing_org = OrgStore.get_org_by_id(org_id)
if not existing_org:
raise ValueError(f'Organization with ID {org_id} not found')
# Check if user is a member of this organization
if not OrgService.is_org_member(user_id, org_id):
logger.warning(
'Non-member attempted to update organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
raise PermissionError(
'User must be a member of the organization to update it'
)
# Check if update contains any LLM settings
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
if llm_fields_being_updated:
# Verify user has admin or owner role
has_permission = OrgService.has_admin_or_owner_role(user_id, org_id)
if not has_permission:
logger.warning(
'User attempted to update LLM settings without permission',
extra={
'user_id': user_id,
'org_id': str(org_id),
'attempted_fields': list(llm_fields_being_updated),
},
)
raise PermissionError(
'Admin or owner role required to update LLM settings'
)
logger.debug(
'User has permission to update LLM settings',
extra={
'user_id': user_id,
'org_id': str(org_id),
'llm_fields': list(llm_fields_being_updated),
},
)
# Convert to dict for OrgStore (excluding None values)
update_dict = update_data.model_dump(exclude_none=True)
if not update_dict:
logger.info(
'No fields to update',
extra={'org_id': str(org_id), 'user_id': user_id},
)
return existing_org
# Perform the update
try:
updated_org = OrgStore.update_org(org_id, update_dict)
if not updated_org:
raise OrgDatabaseError('Failed to update organization in database')
logger.info(
'Organization updated successfully',
extra={
'org_id': str(org_id),
'user_id': user_id,
'updated_fields': list(update_dict.keys()),
},
)
return updated_org
except Exception as e:
logger.error(
'Failed to update organization',
extra={
'org_id': str(org_id),
'user_id': user_id,
'error': str(e),
},
)
raise OrgDatabaseError(f'Failed to update organization: {str(e)}')
@staticmethod
async def get_org_credits(user_id: str, org_id: UUID) -> float | None:
"""
@@ -662,183 +441,3 @@ class OrgService:
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
return None
@staticmethod
def get_user_orgs_paginated(
user_id: str, page_id: str | None = None, limit: int = 100
):
"""
Get paginated list of organizations for a user.
Args:
user_id: User ID (string that will be converted to UUID)
page_id: Optional page ID (offset as string) for pagination
limit: Maximum number of organizations to return
Returns:
Tuple of (list of Org objects, next_page_id or None)
"""
logger.debug(
'Fetching paginated organizations for user',
extra={'user_id': user_id, 'page_id': page_id, 'limit': limit},
)
# Convert user_id string to UUID
user_uuid = parse_uuid(user_id)
# Fetch organizations from store
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_uuid, page_id=page_id, limit=limit
)
logger.debug(
'Retrieved organizations for user',
extra={
'user_id': user_id,
'org_count': len(orgs),
'has_more': next_page_id is not None,
},
)
return orgs, next_page_id
@staticmethod
async def get_org_by_id(org_id: UUID, user_id: str) -> Org:
"""
Get organization by ID with membership validation.
This method verifies that the user is a member of the organization
before returning the organization details.
Args:
org_id: Organization ID
user_id: User ID (string that will be converted to UUID)
Returns:
Org: The organization object
Raises:
OrgNotFoundError: If organization not found or user is not a member
"""
logger.info(
'Retrieving organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
# Verify user is a member of the organization
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if not org_member:
logger.warning(
'User is not a member of organization or organization does not exist',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise OrgNotFoundError(str(org_id))
# Retrieve organization
org = OrgStore.get_org_by_id(org_id)
if not org:
logger.error(
'Organization not found despite valid membership',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise OrgNotFoundError(str(org_id))
logger.info(
'Successfully retrieved organization',
extra={
'org_id': str(org.id),
'org_name': org.name,
'user_id': user_id,
},
)
return org
@staticmethod
def verify_owner_authorization(user_id: str, org_id: UUID) -> None:
"""
Verify that the user is the owner of the organization.
Args:
user_id: User ID to check
org_id: Organization ID
Raises:
OrgNotFoundError: If organization doesn't exist
OrgAuthorizationError: If user is not authorized to delete
"""
# Check if organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise OrgNotFoundError(str(org_id))
# Check if user is a member of the organization
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if not org_member:
raise OrgAuthorizationError('User is not a member of this organization')
# Check if user has owner role
role = RoleStore.get_role_by_id(org_member.role_id)
if not role or role.name != 'owner':
raise OrgAuthorizationError(
'Only organization owners can delete organizations'
)
logger.debug(
'User authorization verified for organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id), 'role': role.name},
)
@staticmethod
async def delete_org_with_cleanup(user_id: str, org_id: UUID) -> Org:
"""
Delete organization with complete cleanup of all associated data.
This method performs the complete organization deletion workflow:
1. Verifies user authorization (owner only)
2. Performs database cascade deletion and LiteLLM cleanup in single transaction
Args:
user_id: User ID requesting deletion (must be owner)
org_id: Organization ID to delete
Returns:
Org: The deleted organization details
Raises:
OrgNotFoundError: If organization doesn't exist
OrgAuthorizationError: If user is not authorized to delete
OrgDatabaseError: If database operations or LiteLLM cleanup fail
"""
logger.info(
'Starting organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
# Step 1: Verify user authorization
OrgService.verify_owner_authorization(user_id, org_id)
# Step 2: Perform database cascade deletion with LiteLLM cleanup in transaction
try:
deleted_org = await OrgStore.delete_org_cascade(org_id)
if not deleted_org:
# This shouldn't happen since we verified existence above
raise OrgDatabaseError('Organization not found during deletion')
logger.info(
'Organization deletion completed successfully',
extra={
'user_id': user_id,
'org_id': str(org_id),
'org_name': deleted_org.name,
},
)
return deleted_org
except Exception as e:
logger.error(
'Organization deletion failed',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
-175
View File
@@ -10,10 +10,8 @@ from server.constants import (
ORG_SETTINGS_VERSION,
get_default_litellm_model,
)
from sqlalchemy import text
from sqlalchemy.orm import joinedload
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.org_member import OrgMember
from storage.user import User
@@ -98,63 +96,6 @@ class OrgStore:
orgs = session.query(Org).all()
return orgs
@staticmethod
def get_user_orgs_paginated(
user_id: UUID, page_id: str | None = None, limit: int = 100
) -> tuple[list[Org], str | None]:
"""
Get paginated list of organizations for a user.
Args:
user_id: User UUID
page_id: Optional page ID (offset as string) for pagination
limit: Maximum number of organizations to return
Returns:
Tuple of (list of Org objects, next_page_id or None)
"""
with session_maker() as session:
# Build query joining OrgMember with Org
query = (
session.query(Org)
.join(OrgMember, Org.id == OrgMember.org_id)
.filter(OrgMember.user_id == user_id)
.order_by(Org.name)
)
# Apply pagination offset
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Fetch limit + 1 to check if there are more results
query = query.limit(limit + 1)
orgs = query.all()
# Check if there are more results
has_more = len(orgs) > limit
if has_more:
orgs = orgs[:limit]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
# Validate org versions
validated_orgs = [
OrgStore._validate_org_version(org) for org in orgs if org
]
validated_orgs = [org for org in validated_orgs if org is not None]
return validated_orgs, next_page_id
@staticmethod
def update_org(
org_id: UUID,
@@ -245,119 +186,3 @@ class OrgStore:
session.commit()
session.refresh(org)
return org
@staticmethod
async def delete_org_cascade(org_id: UUID) -> Org | None:
"""
Delete organization and all associated data in cascade, including external LiteLLM cleanup.
Args:
org_id: UUID of the organization to delete
Returns:
Org: The deleted organization object, or None if not found
Raises:
Exception: If database operations or LiteLLM cleanup fail
"""
with session_maker() as session:
# First get the organization to return it
org = session.query(Org).filter(Org.id == org_id).first()
if not org:
return None
try:
# 1. Delete conversation data for organization conversations
session.execute(
text("""
DELETE FROM conversation_metadata
WHERE conversation_id IN (
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
)
"""),
{'org_id': str(org_id)},
)
session.execute(
text("""
DELETE FROM app_conversation_start_task
WHERE app_conversation_id::text IN (
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
)
"""),
{'org_id': str(org_id)},
)
# 2. Delete organization-owned data tables (direct org_id foreign keys)
session.execute(
text('DELETE FROM billing_sessions WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text(
'DELETE FROM conversation_metadata_saas WHERE org_id = :org_id'
),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM custom_secrets WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM api_keys WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM slack_conversation WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM slack_users WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM stripe_customers WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
# 3. Delete organization memberships
session.execute(
text('DELETE FROM org_member WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
# 4. Handle users with this as current_org_id
session.execute(
text(
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
),
{'org_id': str(org_id)},
)
# 5. Finally delete the organization
session.delete(org)
# 6. Clean up LiteLLM team before committing transaction
logger.info(
'Deleting LiteLLM team within database transaction',
extra={'org_id': str(org_id)},
)
await LiteLlmManager.delete_team(str(org_id))
# 7. Commit all changes only if everything succeeded
session.commit()
logger.info(
'Successfully deleted organization and all associated data including LiteLLM team',
extra={'org_id': str(org_id), 'org_name': org.name},
)
return org
except Exception as e:
session.rollback()
logger.error(
'Failed to delete organization - transaction rolled back',
extra={'org_id': str(org_id), 'error': str(e)},
)
raise
+1 -17
View File
@@ -4,9 +4,7 @@ Store class for managing roles.
from typing import List, Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from storage.database import a_session_maker, session_maker
from storage.database import session_maker
from storage.role import Role
@@ -35,20 +33,6 @@ class RoleStore:
with session_maker() as session:
return session.query(Role).filter(Role.name == name).first()
@staticmethod
async def get_role_by_name_async(
name: str,
session: Optional[AsyncSession] = None,
) -> Optional[Role]:
"""Get role by name."""
if session is not None:
result = await session.execute(select(Role).where(Role.name == name))
return result.scalars().first()
async with a_session_maker() as session:
result = await session.execute(select(Role).where(Role.name == name))
return result.scalars().first()
@staticmethod
def list_roles() -> List[Role]:
"""List all roles."""
-2
View File
@@ -31,8 +31,6 @@ class User(Base): # type: ignore
user_consents_to_analytics = Column(Boolean, nullable=True)
email = Column(String, nullable=True)
email_verified = Column(Boolean, nullable=True)
git_user_name = Column(String, nullable=True)
git_user_email = Column(String, nullable=True)
# Relationships
role = relationship('Role', back_populates='users')
+19 -400
View File
@@ -14,13 +14,10 @@ from server.constants import (
get_default_litellm_model,
)
from server.logger import logger
from sqlalchemy import select, text
from sqlalchemy import text
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker, session_maker
from storage.encrypt_utils import (
decrypt_legacy_model,
encrypt_legacy_value,
)
from storage.database import session_maker
from storage.encrypt_utils import decrypt_legacy_model
from storage.org import Org
from storage.org_member import OrgMember
from storage.role_store import RoleStore
@@ -119,7 +116,7 @@ class UserStore:
redis_client = UserStore._get_redis_client()
if redis_client is None:
logger.warning(
'user_store:_acquire_user_creation_lock:no_redis_client',
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
extra={'user_id': user_id},
)
return True # Proceed without locking if Redis is unavailable
@@ -162,20 +159,12 @@ class UserStore:
from storage.lite_llm_manager import LiteLlmManager
logger.debug(
'user_store:migrate_user:calling_litellm_migrate_entries',
extra={'user_id': user_id},
)
await LiteLlmManager.migrate_entries(
str(org.id),
user_id,
decrypted_user_settings,
)
logger.debug(
'user_store:migrate_user:done_litellm_migrate_entries',
extra={'user_id': user_id},
)
custom_settings = UserStore._has_custom_settings(
decrypted_user_settings, user_settings.user_version
)
@@ -183,15 +172,7 @@ class UserStore:
# avoids circular reference. This migrate method is temprorary until all users are migrated.
from integrations.stripe_service import migrate_customer
logger.debug(
'user_store:migrate_user:calling_stripe_migrate_customer',
extra={'user_id': user_id},
)
await migrate_customer(session, user_id, org)
logger.debug(
'user_store:migrate_user:done_stripe_migrate_customer',
extra={'user_id': user_id},
)
from storage.org_store import OrgStore
@@ -220,15 +201,7 @@ class UserStore:
)
session.add(user)
logger.debug(
'user_store:migrate_user:calling_get_role_by_name',
extra={'user_id': user_id},
)
role = await RoleStore.get_role_by_name_async('owner')
logger.debug(
'user_store:migrate_user:done_get_role_by_name',
extra={'user_id': user_id},
)
role = RoleStore.get_role_by_name('owner')
from storage.org_member_store import OrgMemberStore
@@ -241,6 +214,7 @@ class UserStore:
if not custom_settings:
del org_member_kwargs['llm_model']
del org_member_kwargs['llm_base_url']
del org_member_kwargs['llm_api_key_for_byor']
org_member = OrgMember(
org_id=org.id,
@@ -255,10 +229,6 @@ class UserStore:
user_settings.already_migrated = True
session.merge(user_settings)
session.flush()
logger.debug(
'user_store:migrate_user:session_flush_complete',
extra={'user_id': user_id},
)
# need to migrate conversation metadata
session.execute(
@@ -326,262 +296,8 @@ class UserStore:
session.commit()
session.refresh(user)
user.org_members # load org_members
logger.debug(
'user_store:migrate_user:session_committed',
extra={'user_id': user_id},
)
return user
@staticmethod
async def downgrade_user(user_id: str) -> UserSettings | None:
"""Downgrade a migrated user back to the pre-migration state.
This reverses the migrate_user operation:
1. Get the user's settings from user_settings table (migrated users) or
create new user_settings from org_members table (new sign-ups)
2. Call LiteLlmManager.downgrade_entries to revert LiteLLM state
3. Copy user_id from conversation_metadata_saas to conversation_metadata
4. Delete conversation_metadata_saas entries
5. Reset org_id columns in related tables (stripe_customers, slack_users, etc.)
6. Delete the org_member and org entries
7. Delete the user entry
8. Set already_migrated=False on user_settings
For new sign-ups (users who registered after migration was deployed),
there won't be an existing user_settings entry. In this case, we fall back
to the org_members table to get the user's API keys and settings, and create
a new user_settings entry for them.
Args:
user_id: The Keycloak user ID to downgrade
Returns:
The user_settings if downgrade was successful, None otherwise.
Returns None if the org has multiple members (not a personal org).
"""
logger.info(
'user_store:downgrade_user:start',
extra={'user_id': user_id},
)
with session_maker() as session:
# Get the user and their org_member
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
if not user:
logger.warning(
'user_store:downgrade_user:user_not_found',
extra={'user_id': user_id},
)
return None
# Get the user's personal org (org_id == user_id)
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
if not org:
logger.warning(
'user_store:downgrade_user:org_not_found',
extra={'user_id': user_id},
)
return None
# Get the user_settings (for migrated users)
user_settings = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(True),
)
.first()
)
# For new sign-ups after migration, user_settings won't exist
# Fall back to getting data from org_members
is_new_signup = False
if not user_settings:
logger.info(
'user_store:downgrade_user:user_settings_not_found_checking_org_members',
extra={'user_id': user_id},
)
# Get org_members for this org - should only be one for personal orgs
org_members = (
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
)
if len(org_members) != 1:
logger.error(
'user_store:downgrade_user:unexpected_org_members_count',
extra={
'user_id': user_id,
'org_id': str(org.id),
'org_members_count': len(org_members),
},
)
return None
org_member = org_members[0]
is_new_signup = True
# Create a new user_settings entry from OrgMember, User, and Org data
# This is needed for new sign-ups who don't have user_settings
user_settings = UserStore._create_user_settings_from_entities(
user_id, org_member, user, org
)
session.add(user_settings)
session.flush()
logger.info(
'user_store:downgrade_user:created_user_settings_from_org_member',
extra={'user_id': user_id},
)
# Call LiteLLM downgrade
from storage.lite_llm_manager import LiteLlmManager
logger.debug(
'user_store:downgrade_user:calling_litellm_downgrade_entries',
extra={'user_id': user_id},
)
# Get the API keys for LiteLLM downgrade
if is_new_signup:
# For new signups, we already have decrypted values in user_settings
decrypted_user_settings = user_settings
else:
# For migrated users, decrypt the legacy model
kwargs = decrypt_legacy_model(
[
'llm_api_key',
'llm_api_key_for_byor',
'search_api_key',
'sandbox_api_key',
],
user_settings,
)
decrypted_user_settings = UserSettings(**kwargs)
await LiteLlmManager.downgrade_entries(
str(org.id),
user_id,
decrypted_user_settings,
)
logger.debug(
'user_store:downgrade_user:done_litellm_downgrade_entries',
extra={'user_id': user_id},
)
user_uuid = uuid.UUID(user_id)
# Step 3: Copy user_id from conversation_metadata_saas to conversation_metadata
# This ensures any conversations created after migration have their user_id
# preserved in the original table before we delete the saas entries
session.execute(
text("""
UPDATE conversation_metadata
SET user_id = :user_id
WHERE conversation_id IN (
SELECT conversation_id
FROM conversation_metadata_saas
WHERE user_id = :user_uuid
)
"""),
{'user_id': user_id, 'user_uuid': user_uuid},
)
# Step 4: Delete conversation_metadata_saas entries
session.execute(
text('DELETE FROM conversation_metadata_saas WHERE user_id = :user_id'),
{'user_id': user_uuid},
)
# Step 5: Reset org_id columns in related tables
# Reset stripe_customers
session.execute(
text(
'UPDATE stripe_customers SET org_id = NULL WHERE org_id = :org_id'
),
{'org_id': user_uuid},
)
# Reset slack_users
session.execute(
text('UPDATE slack_users SET org_id = NULL WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Reset slack_conversation
session.execute(
text(
'UPDATE slack_conversation SET org_id = NULL WHERE org_id = :org_id'
),
{'org_id': user_uuid},
)
# Reset api_keys
session.execute(
text('UPDATE api_keys SET org_id = NULL WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Reset custom_secrets
session.execute(
text('UPDATE custom_secrets SET org_id = NULL WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Reset billing_sessions
session.execute(
text(
'UPDATE billing_sessions SET org_id = NULL WHERE org_id = :org_id'
),
{'org_id': user_uuid},
)
# Step 6: Delete org_member entries for this org
session.execute(
text('DELETE FROM org_member WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Step 7: Delete the user entry
session.execute(
text('DELETE FROM "user" WHERE id = :user_id'),
{'user_id': user_uuid},
)
# Delete the org entry
session.execute(
text('DELETE FROM org WHERE id = :org_id'),
{'org_id': user_uuid},
)
# Step 8: Set already_migrated=False on user_settings and encrypt fields
user_settings.already_migrated = False
# Re-encrypt the sensitive fields before storing in the DB
encrypt_keys = [
'llm_api_key',
'llm_api_key_for_byor',
'search_api_key',
'sandbox_api_key',
]
for key in encrypt_keys:
value = getattr(user_settings, key, None)
if value is not None:
setattr(user_settings, key, encrypt_legacy_value(value))
session.merge(user_settings)
session.commit()
logger.info(
'user_store:downgrade_user:complete',
extra={'user_id': user_id},
)
return user_settings
@staticmethod
def get_user_by_id(user_id: str) -> Optional[User]:
"""Get user by Keycloak user ID (sync version).
@@ -606,7 +322,7 @@ class UserStore:
):
# The user is already being created in another thread / process
logger.info(
'user_store:create_default_settings:waiting_for_lock',
'saas_settings_store:create_default_settings:waiting_for_lock',
extra={'user_id': user_id},
)
call_async_from_sync(
@@ -656,13 +372,13 @@ class UserStore:
This is the preferred method when calling from an async context as it
avoids event loop conflicts that can occur with the sync version.
"""
async with a_session_maker() as session:
result = await session.execute(
select(User)
with session_maker() as session:
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
user = result.scalars().first()
if user:
return user
@@ -670,39 +386,32 @@ class UserStore:
while not await UserStore._acquire_user_creation_lock(user_id):
# The user is already being created in another thread / process
logger.info(
'user_store:get_user_by_id_async:waiting_for_lock',
'saas_settings_store:create_default_settings:waiting_for_lock',
extra={'user_id': user_id},
)
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
# Check for user again as migration could have happened while trying to get the lock.
result = await session.execute(
select(User)
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
user = result.scalars().first()
if user:
return user
logger.info(
'user_store:get_user_by_id_async:start_migration',
extra={'user_id': user_id},
)
result = await session.execute(
select(UserSettings).filter(
user_settings = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
)
.first()
)
user_settings = result.scalars().first()
if user_settings:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(user_id)
logger.info(
'user_store:get_user_by_id_async:calling_migrate_user',
extra={'user_id': user_id},
)
user = await UserStore.migrate_user(
user_id,
user_settings,
@@ -772,96 +481,6 @@ class UserStore:
}
return kwargs
@staticmethod
def _create_user_settings_from_entities(
user_id: str, org_member: OrgMember, user: User, org: Org
) -> UserSettings:
"""Create UserSettings from OrgMember, User, and Org data.
Uses OrgMember values first. If an OrgMember field is None and there's
a corresponding "default_" field in Org, use the Org value.
Also pulls relevant fields from User.
Args:
user_id: The Keycloak user ID
org_member: The OrgMember entity
user: The User entity
org: The Org entity
Returns:
A new UserSettings object populated from the entities
"""
# Mapping from OrgMember fields to corresponding Org "default_" fields
org_member_to_org_default = {
'llm_model': 'default_llm_model',
'llm_base_url': 'default_llm_base_url',
'max_iterations': 'default_max_iterations',
}
def get_value_with_org_fallback(field_name: str, org_member_value):
"""Get value from OrgMember, falling back to Org default if None."""
if org_member_value is not None:
return org_member_value
org_default_field = org_member_to_org_default.get(field_name)
if org_default_field and hasattr(org, org_default_field):
return getattr(org, org_default_field)
return None
# Get values from OrgMember with Org fallback for fields with default_ prefix
llm_model = get_value_with_org_fallback('llm_model', org_member.llm_model)
llm_base_url = get_value_with_org_fallback(
'llm_base_url', org_member.llm_base_url
)
max_iterations = get_value_with_org_fallback(
'max_iterations', org_member.max_iterations
)
return UserSettings(
keycloak_user_id=user_id,
# OrgMember fields
llm_api_key=org_member.llm_api_key.get_secret_value()
if org_member.llm_api_key
else None,
llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value()
if org_member.llm_api_key_for_byor
else None,
llm_model=llm_model,
llm_base_url=llm_base_url,
max_iterations=max_iterations,
# User fields
accepted_tos=user.accepted_tos,
enable_sound_notifications=user.enable_sound_notifications,
language=user.language,
user_consents_to_analytics=user.user_consents_to_analytics,
email=user.email,
email_verified=user.email_verified,
git_user_name=user.git_user_name,
git_user_email=user.git_user_email,
# Org fields
agent=org.agent,
security_analyzer=org.security_analyzer,
confirmation_mode=org.confirmation_mode,
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
enable_default_condenser=org.enable_default_condenser,
billing_margin=org.billing_margin,
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
sandbox_base_container_image=org.sandbox_base_container_image,
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
user_version=org.org_version,
mcp_config=org.mcp_config,
search_api_key=org.search_api_key.get_secret_value()
if org.search_api_key
else None,
sandbox_api_key=org.sandbox_api_key.get_secret_value()
if org.sandbox_api_key
else None,
max_budget_per_task=org.max_budget_per_task,
enable_solvability_analysis=org.enable_solvability_analysis,
v1_enabled=org.v1_enabled,
condenser_max_size=org.condenser_max_size,
already_migrated=False,
)
@staticmethod
def _has_custom_settings(
user_settings: UserSettings, old_user_version: int | None
@@ -6,13 +6,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from integrations.jira.jira_manager import JiraManager
from integrations.jira.jira_payload import (
JiraEventType,
JiraWebhookPayload,
)
from integrations.jira.jira_view import (
JiraNewConversationView,
)
from integrations.models import JobContext
from jinja2 import DictLoader, Environment
from storage.jira_conversation import JiraConversation
from storage.jira_user import JiraUser
@@ -27,7 +24,7 @@ def mock_token_manager():
"""Create a mock TokenManager for testing."""
token_manager = MagicMock()
token_manager.get_user_id_from_user_email = AsyncMock()
token_manager.decrypt_text = MagicMock(return_value='decrypted_key')
token_manager.decrypt_text = MagicMock()
return token_manager
@@ -62,7 +59,6 @@ def sample_jira_workspace():
workspace = MagicMock(spec=JiraWorkspace)
workspace.id = 1
workspace.name = 'test.atlassian.net'
workspace.jira_cloud_id = 'cloud-123'
workspace.admin_user_id = 'admin_id'
workspace.webhook_secret = 'encrypted_secret'
workspace.svc_acc_email = 'service@example.com'
@@ -78,41 +74,22 @@ def sample_user_auth():
user_auth.get_provider_tokens = AsyncMock(return_value={})
user_auth.get_access_token = AsyncMock(return_value='test_token')
user_auth.get_user_id = AsyncMock(return_value='test_user_id')
user_auth.get_secrets = AsyncMock(return_value=None)
return user_auth
@pytest.fixture
def sample_webhook_payload():
"""Create a sample JiraWebhookPayload for testing."""
return JiraWebhookPayload(
event_type=JiraEventType.COMMENT_MENTION,
raw_event='comment_created',
def sample_job_context():
"""Create a sample JobContext for testing."""
return JobContext(
issue_id='12345',
issue_key='TEST-123',
user_msg='Fix this bug @openhands',
user_email='user@test.com',
display_name='Test User',
account_id='user123',
workspace_name='test.atlassian.net',
base_api_url='https://test.atlassian.net',
comment_body='Fix this bug @openhands',
)
@pytest.fixture
def sample_label_webhook_payload():
"""Create a sample labeled ticket JiraWebhookPayload for testing."""
return JiraWebhookPayload(
event_type=JiraEventType.LABELED_TICKET,
raw_event='jira:issue_updated',
issue_id='12345',
issue_key='PROJ-123',
user_email='user@company.com',
display_name='Test User',
account_id='user456',
workspace_name='jira.company.com',
base_api_url='https://jira.company.com',
comment_body='',
issue_title='Test Issue',
issue_description='This is a test issue',
)
@@ -203,17 +180,16 @@ def jira_conversation():
@pytest.fixture
def new_conversation_view(
sample_webhook_payload, sample_user_auth, sample_jira_user, sample_jira_workspace
sample_job_context, sample_user_auth, sample_jira_user, sample_jira_workspace
):
"""JiraNewConversationView instance for testing"""
return JiraNewConversationView(
payload=sample_webhook_payload,
job_context=sample_job_context,
saas_user_auth=sample_user_auth,
jira_user=sample_jira_user,
jira_workspace=sample_jira_workspace,
selected_repo='test/repo1',
conversation_id='conv-123',
_decrypted_api_key='decrypted_key',
)
File diff suppressed because it is too large Load Diff
@@ -2,90 +2,29 @@
Tests for Jira view classes and factory.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import pytest
from integrations.jira.jira_payload import (
JiraEventType,
JiraPayloadError,
JiraPayloadParser,
JiraPayloadSkipped,
JiraPayloadSuccess,
)
from integrations.jira.jira_types import RepositoryNotFoundError, StartingConvoException
from integrations.jira.jira_types import StartingConvoException
from integrations.jira.jira_view import (
JiraFactory,
JiraNewConversationView,
)
from integrations.models import Message, SourceType
class TestJiraNewConversationView:
"""Tests for JiraNewConversationView"""
@pytest.mark.asyncio
async def test_get_issue_details_success(
self, new_conversation_view, sample_jira_workspace
):
"""Test successful issue details retrieval."""
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
title, description = await new_conversation_view.get_issue_details()
assert title == 'Test Issue'
assert description == 'Test description'
@pytest.mark.asyncio
async def test_get_issue_details_cached(self, new_conversation_view):
"""Test issue details are cached after first call."""
new_conversation_view._issue_title = 'Cached Title'
new_conversation_view._issue_description = 'Cached Description'
title, description = await new_conversation_view.get_issue_details()
assert title == 'Cached Title'
assert description == 'Cached Description'
@pytest.mark.asyncio
async def test_get_issue_details_no_title(self, new_conversation_view):
"""Test issue details with no title raises error."""
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': '', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(StartingConvoException, match='does not have a title'):
await new_conversation_view.get_issue_details()
@pytest.mark.asyncio
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
"""Test _get_instructions method fetches issue details."""
new_conversation_view._issue_title = 'Test Issue'
new_conversation_view._issue_description = 'This is a test issue'
instructions, user_msg = await new_conversation_view._get_instructions(
mock_jinja_env
)
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
assert instructions == 'Test Jira instructions template'
assert 'TEST-123' in user_msg
assert 'Test Issue' in user_msg
assert 'Fix this bug @openhands' in user_msg
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.create_new_conversation')
@patch('integrations.jira.jira_view.integration_store')
async def test_create_or_update_conversation_success(
@@ -97,8 +36,6 @@ class TestJiraNewConversationView:
mock_agent_loop_info,
):
"""Test successful conversation creation"""
new_conversation_view._issue_title = 'Test Issue'
new_conversation_view._issue_description = 'Test description'
mock_create_conversation.return_value = mock_agent_loop_info
mock_store.create_conversation = AsyncMock()
@@ -110,7 +47,6 @@ class TestJiraNewConversationView:
mock_create_conversation.assert_called_once()
mock_store.create_conversation.assert_called_once()
@pytest.mark.asyncio
async def test_create_or_update_conversation_no_repo(
self, new_conversation_view, mock_jinja_env
):
@@ -120,6 +56,18 @@ class TestJiraNewConversationView:
with pytest.raises(StartingConvoException, match='No repository selected'):
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
@patch('integrations.jira.jira_view.create_new_conversation')
async def test_create_or_update_conversation_failure(
self, mock_create_conversation, new_conversation_view, mock_jinja_env
):
"""Test conversation creation failure"""
mock_create_conversation.side_effect = Exception('Creation failed')
with pytest.raises(
StartingConvoException, match='Failed to create conversation'
):
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
def test_get_response_msg(self, new_conversation_view):
"""Test get_response_msg method"""
response = new_conversation_view.get_response_msg()
@@ -133,333 +81,469 @@ class TestJiraNewConversationView:
class TestJiraFactory:
"""Tests for JiraFactory"""
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
@patch('integrations.jira.jira_view.infer_repo_from_message')
async def test_create_view_success(
@patch('integrations.jira.jira_view.integration_store')
async def test_create_jira_view_from_payload_new_conversation(
self,
mock_infer_repos,
mock_create_handler,
sample_webhook_payload,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
sample_repositories,
):
"""Test factory creating view with repo selection."""
# Setup mock provider handler
mock_handler = MagicMock()
mock_handler.verify_repo_provider = AsyncMock(
return_value=sample_repositories[0]
)
mock_create_handler.return_value = mock_handler
# Mock repo inference to return a repo name
mock_infer_repos.return_value = ['test/repo1']
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
view = await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
assert isinstance(view, JiraNewConversationView)
assert view.selected_repo == 'test/repo1'
mock_handler.verify_repo_provider.assert_called_once_with('test/repo1')
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
@patch('integrations.jira.jira_view.infer_repo_from_message')
async def test_create_view_no_repo_in_text(
self,
mock_infer_repos,
mock_create_handler,
sample_webhook_payload,
mock_store,
sample_job_context,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
):
"""Test factory raises error when no repo mentioned in text."""
mock_handler = MagicMock()
mock_create_handler.return_value = mock_handler
"""Test factory creating new conversation view"""
mock_store.get_user_conversations_by_issue_id = AsyncMock(return_value=None)
# No repos found in text
mock_infer_repos.return_value = []
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(
RepositoryNotFoundError, match='Could not determine which repository'
):
await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
@patch('integrations.jira.jira_view.infer_repo_from_message')
async def test_create_view_repo_verification_fails(
self,
mock_infer_repos,
mock_create_handler,
sample_webhook_payload,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
):
"""Test factory raises error when repo verification fails."""
mock_handler = MagicMock()
mock_handler.verify_repo_provider = AsyncMock(
side_effect=Exception('Repository not found')
)
mock_create_handler.return_value = mock_handler
# Repos found in text but verification fails
mock_infer_repos.return_value = ['test/repo1', 'test/repo2']
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(
RepositoryNotFoundError,
match='Could not access any of the mentioned repositories',
):
await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
@patch('integrations.jira.jira_view.infer_repo_from_message')
async def test_create_view_multiple_repos_verified(
self,
mock_infer_repos,
mock_create_handler,
sample_webhook_payload,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
sample_repositories,
):
"""Test factory raises error when multiple repos are verified."""
mock_handler = MagicMock()
# Both repos verify successfully
mock_handler.verify_repo_provider = AsyncMock(
side_effect=[sample_repositories[0], sample_repositories[1]]
)
mock_create_handler.return_value = mock_handler
# Multiple repos found in text
mock_infer_repos.return_value = ['test/repo1', 'test/repo2']
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(
RepositoryNotFoundError, match='Multiple repositories found'
):
await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
async def test_create_view_no_provider(
self,
mock_create_handler,
sample_webhook_payload,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
):
"""Test factory raises error when no provider is connected."""
mock_create_handler.return_value = None
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(
RepositoryNotFoundError, match='No Git provider connected'
):
await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
class TestJiraPayloadParser:
"""Tests for JiraPayloadParser"""
@pytest.fixture
def parser(self):
"""Create a parser for testing."""
return JiraPayloadParser(oh_label='openhands', inline_oh_label='@openhands')
def test_parse_label_event_success(
self, parser, sample_issue_update_webhook_payload
):
"""Test parsing label event."""
result = parser.parse(sample_issue_update_webhook_payload)
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.LABELED_TICKET
assert result.payload.issue_key == 'PROJ-123'
def test_parse_comment_event_success(self, parser, sample_comment_webhook_payload):
"""Test parsing comment event."""
result = parser.parse(sample_comment_webhook_payload)
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.COMMENT_MENTION
assert result.payload.issue_key == 'TEST-123'
assert '@openhands' in result.payload.comment_body
def test_parse_unknown_event_skipped(self, parser):
"""Test unknown event is skipped."""
payload = {'webhookEvent': 'unknown_event'}
result = parser.parse(payload)
assert isinstance(result, JiraPayloadSkipped)
assert 'Unhandled webhook event type' in result.skip_reason
def test_parse_label_event_wrong_label_skipped(self, parser):
"""Test label event without OH label is skipped."""
payload = {
'webhookEvent': 'jira:issue_updated',
'changelog': {'items': [{'field': 'labels', 'toString': 'other-label'}]},
}
result = parser.parse(payload)
assert isinstance(result, JiraPayloadSkipped)
assert 'does not contain' in result.skip_reason
def test_parse_comment_event_no_mention_skipped(self, parser):
"""Test comment without mention is skipped."""
payload = {
'webhookEvent': 'comment_created',
'comment': {
'body': 'Regular comment',
'author': {'emailAddress': 'test@test.com'},
},
}
result = parser.parse(payload)
assert isinstance(result, JiraPayloadSkipped)
assert 'does not mention' in result.skip_reason
def test_parse_missing_fields_error(self, parser):
"""Test missing required fields returns error."""
payload = {
'webhookEvent': 'jira:issue_updated',
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands'}]},
'issue': {'id': '123'}, # Missing key
'user': {'emailAddress': 'test@test.com'}, # Missing other fields
}
result = parser.parse(payload)
assert isinstance(result, JiraPayloadError)
assert 'Missing required fields' in result.error
class TestJiraPayloadParserStagingLabels:
"""Tests for JiraPayloadParser with staging labels."""
@pytest.fixture
def staging_parser(self):
"""Create a parser with staging labels."""
return JiraPayloadParser(
oh_label='openhands-exp', inline_oh_label='@openhands-exp'
view = await JiraFactory.create_jira_view_from_payload(
sample_job_context,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
)
def test_parse_staging_label(self, staging_parser):
"""Test parsing with staging label."""
payload = {
'webhookEvent': 'jira:issue_updated',
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands-exp'}]},
'issue': {
'id': '123',
'key': 'TEST-1',
'self': 'https://test.atlassian.net/rest/api/2/issue/123',
},
'user': {
'emailAddress': 'test@test.com',
'displayName': 'Test',
'accountId': 'acc123',
'self': 'https://test.atlassian.net/rest/api/2/user',
},
}
result = staging_parser.parse(payload)
assert isinstance(view, JiraNewConversationView)
assert view.conversation_id == ''
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.LABELED_TICKET
async def test_create_jira_view_from_payload_no_user(
self, sample_job_context, sample_user_auth, sample_jira_workspace
):
"""Test factory with no Jira user"""
with pytest.raises(StartingConvoException, match='User not authenticated'):
await JiraFactory.create_jira_view_from_payload(
sample_job_context,
sample_user_auth,
None,
sample_jira_workspace, # type: ignore
)
def test_parse_prod_label_in_staging_skipped(self, staging_parser):
"""Test prod label is skipped in staging environment."""
payload = {
'webhookEvent': 'jira:issue_updated',
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands'}]},
}
result = staging_parser.parse(payload)
async def test_create_jira_view_from_payload_no_auth(
self, sample_job_context, sample_jira_user, sample_jira_workspace
):
"""Test factory with no SaaS auth"""
with pytest.raises(StartingConvoException, match='User not authenticated'):
await JiraFactory.create_jira_view_from_payload(
sample_job_context,
None,
sample_jira_user,
sample_jira_workspace, # type: ignore
)
assert isinstance(result, JiraPayloadSkipped)
async def test_create_jira_view_from_payload_no_workspace(
self, sample_job_context, sample_user_auth, sample_jira_user
):
"""Test factory with no workspace"""
with pytest.raises(StartingConvoException, match='User not authenticated'):
await JiraFactory.create_jira_view_from_payload(
sample_job_context,
sample_user_auth,
sample_jira_user,
None, # type: ignore
)
class TestJiraViewEdgeCases:
"""Tests for edge cases and error scenarios"""
@patch('integrations.jira.jira_view.create_new_conversation')
@patch('integrations.jira.jira_view.integration_store')
async def test_conversation_creation_with_no_user_secrets(
self,
mock_store,
mock_create_conversation,
new_conversation_view,
mock_jinja_env,
mock_agent_loop_info,
):
"""Test conversation creation when user has no secrets"""
new_conversation_view.saas_user_auth.get_secrets.return_value = None
mock_create_conversation.return_value = mock_agent_loop_info
mock_store.create_conversation = AsyncMock()
result = await new_conversation_view.create_or_update_conversation(
mock_jinja_env
)
assert result == 'conv-123'
# Verify create_new_conversation was called with custom_secrets=None
call_kwargs = mock_create_conversation.call_args[1]
assert call_kwargs['custom_secrets'] is None
@patch('integrations.jira.jira_view.create_new_conversation')
@patch('integrations.jira.jira_view.integration_store')
async def test_conversation_creation_store_failure(
self,
mock_store,
mock_create_conversation,
new_conversation_view,
mock_jinja_env,
mock_agent_loop_info,
):
"""Test conversation creation when store creation fails"""
mock_create_conversation.return_value = mock_agent_loop_info
mock_store.create_conversation = AsyncMock(side_effect=Exception('Store error'))
with pytest.raises(
StartingConvoException, match='Failed to create conversation'
):
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
def test_new_conversation_view_attributes(self, new_conversation_view):
"""Test new conversation view attribute access"""
assert new_conversation_view.job_context.issue_key == 'TEST-123'
assert new_conversation_view.selected_repo == 'test/repo1'
assert new_conversation_view.conversation_id == 'conv-123'
class TestJiraFactoryIsLabeledTicket:
"""Parameterized tests for JiraFactory.is_labeled_ticket method."""
@pytest.mark.parametrize(
'payload,expected',
[
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {
'items': [{'field': 'labels', 'toString': 'openhands'}]
},
},
True,
id='issue_updated_with_openhands_label',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {
'items': [
{'field': 'labels', 'toString': 'bug'},
{'field': 'labels', 'toString': 'openhands'},
]
},
},
True,
id='issue_updated_with_multiple_labels_including_openhands',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {
'items': [{'field': 'labels', 'toString': 'bug,urgent'}]
},
},
False,
id='issue_updated_without_openhands_label',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {'items': []},
},
False,
id='issue_updated_with_empty_changelog_items',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {},
},
False,
id='issue_updated_with_empty_changelog',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
},
False,
id='issue_updated_without_changelog',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'changelog': {
'items': [{'field': 'labels', 'toString': 'openhands'}]
},
},
False,
id='comment_created_event_with_label',
),
pytest.param(
{
'webhookEvent': 'issue_deleted',
'changelog': {
'items': [{'field': 'labels', 'toString': 'openhands'}]
},
},
False,
id='unsupported_event_type',
),
pytest.param(
{},
False,
id='empty_payload',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {
'items': [{'field': 'status', 'toString': 'In Progress'}]
},
},
False,
id='issue_updated_with_non_label_field',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {
'items': [{'field': 'labels', 'fromString': 'openhands'}]
},
},
False,
id='issue_updated_with_fromString_instead_of_toString',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {
'items': [
{'field': 'labels', 'toString': 'not-openhands'},
{'field': 'priority', 'toString': 'High'},
]
},
},
False,
id='issue_updated_with_mixed_fields_no_openhands',
),
],
)
def test_is_labeled_ticket(self, payload, expected):
"""Test is_labeled_ticket with various payloads."""
with patch('integrations.jira.jira_view.OH_LABEL', 'openhands'):
message = Message(source=SourceType.JIRA, message={'payload': payload})
result = JiraFactory.is_labeled_ticket(message)
assert result == expected
@pytest.mark.parametrize(
'payload,expected',
[
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {
'items': [{'field': 'labels', 'toString': 'openhands-exp'}]
},
},
True,
id='issue_updated_with_staging_label',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'changelog': {
'items': [{'field': 'labels', 'toString': 'openhands'}]
},
},
False,
id='issue_updated_with_prod_label_in_staging_env',
),
],
)
def test_is_labeled_ticket_staging_labels(self, payload, expected):
"""Test is_labeled_ticket with staging environment labels."""
with patch('integrations.jira.jira_view.OH_LABEL', 'openhands-exp'):
message = Message(source=SourceType.JIRA, message={'payload': payload})
result = JiraFactory.is_labeled_ticket(message)
assert result == expected
class TestJiraFactoryIsTicketComment:
"""Parameterized tests for JiraFactory.is_ticket_comment method."""
@pytest.mark.parametrize(
'payload,expected',
[
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'Please fix this @openhands'},
},
True,
id='comment_with_openhands_mention',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': '@openhands please help'},
},
True,
id='comment_starting_with_openhands_mention',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'Hello @openhands!'},
},
True,
id='comment_with_openhands_mention_and_punctuation',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': '(@openhands)'},
},
True,
id='comment_with_openhands_in_parentheses',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'Hey @OpenHands can you help?'},
},
True,
id='comment_with_case_insensitive_mention',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'Hey @OPENHANDS!'},
},
True,
id='comment_with_uppercase_mention',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'Regular comment without mention'},
},
False,
id='comment_without_mention',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'Hello @openhands-agent!'},
},
False,
id='comment_with_openhands_as_prefix',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'user@openhands.com'},
},
False,
id='comment_with_openhands_in_email',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': ''},
},
False,
id='comment_with_empty_body',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {},
},
False,
id='comment_without_body',
),
pytest.param(
{
'webhookEvent': 'comment_created',
},
False,
id='comment_created_without_comment_data',
),
pytest.param(
{
'webhookEvent': 'jira:issue_updated',
'comment': {'body': 'Please fix this @openhands'},
},
False,
id='issue_updated_event_with_mention',
),
pytest.param(
{
'webhookEvent': 'issue_deleted',
'comment': {'body': '@openhands'},
},
False,
id='unsupported_event_type_with_mention',
),
pytest.param(
{},
False,
id='empty_payload',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'Multiple @openhands @openhands mentions'},
},
True,
id='comment_with_multiple_mentions',
),
],
)
def test_is_ticket_comment(self, payload, expected):
"""Test is_ticket_comment with various payloads."""
with patch('integrations.jira.jira_view.INLINE_OH_LABEL', '@openhands'), patch(
'integrations.jira.jira_view.has_exact_mention'
) as mock_has_exact_mention:
from integrations.utils import has_exact_mention
mock_has_exact_mention.side_effect = (
lambda text, mention: has_exact_mention(text, mention)
)
message = Message(source=SourceType.JIRA, message={'payload': payload})
result = JiraFactory.is_ticket_comment(message)
assert result == expected
@pytest.mark.parametrize(
'payload,expected',
[
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'Please fix this @openhands-exp'},
},
True,
id='comment_with_staging_mention',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': '@openhands-exp please help'},
},
True,
id='comment_starting_with_staging_mention',
),
pytest.param(
{
'webhookEvent': 'comment_created',
'comment': {'body': 'Please fix this @openhands'},
},
False,
id='comment_with_prod_mention_in_staging_env',
),
],
)
def test_is_ticket_comment_staging_labels(self, payload, expected):
"""Test is_ticket_comment with staging environment labels."""
with patch(
'integrations.jira.jira_view.INLINE_OH_LABEL', '@openhands-exp'
), patch(
'integrations.jira.jira_view.has_exact_mention'
) as mock_has_exact_mention:
from integrations.utils import has_exact_mention
mock_has_exact_mention.side_effect = (
lambda text, mention: has_exact_mention(text, mention)
)
message = Message(source=SourceType.JIRA, message={'payload': payload})
result = JiraFactory.is_ticket_comment(message)
assert result == expected
@@ -1,96 +0,0 @@
from unittest.mock import patch
from urllib.parse import parse_qs, urlparse
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import SecretStr
from server.routes.github_proxy import add_github_proxy_routes
@pytest.fixture
def app_with_github_proxy(monkeypatch):
"""Create a FastAPI app with github proxy routes enabled."""
# Enable the github proxy endpoints
monkeypatch.setenv('GITHUB_PROXY_ENDPOINTS', '1')
# Mock the config to have a jwt_secret
mock_config = type(
'MockConfig', (), {'jwt_secret': SecretStr('test-secret-key-for-testing')}
)()
app = FastAPI()
with patch('server.routes.github_proxy.GITHUB_PROXY_ENDPOINTS', True):
with patch('server.routes.github_proxy.config', mock_config):
add_github_proxy_routes(app)
# Return app and mock_config so we can use the same config in tests
return app, mock_config
def test_state_compress_encrypt_and_decrypt_decompress_roundtrip(
app_with_github_proxy, monkeypatch
):
"""
Verify the code path used by github_proxy_start -> github_proxy_callback:
- compress payload, encrypt, base64-encode (what the start code does)
- base64-decode, decrypt, decompress (what the callback code does)
This test exercises the actual endpoints to verify the roundtrip works correctly.
"""
app, mock_config = app_with_github_proxy
client = TestClient(app)
original_state = 'some-state-value'
original_redirect_uri = 'https://example.com/redirect'
# Call github_proxy_start endpoint - it should redirect to GitHub with encrypted state
with patch('server.routes.github_proxy.config', mock_config):
response = client.get(
'/github-proxy/test-subdomain/login/oauth/authorize',
params={
'state': original_state,
'redirect_uri': original_redirect_uri,
'client_id': 'test-client-id',
},
follow_redirects=False,
)
assert response.status_code == 307
redirect_url = response.headers['location']
# Verify it redirects to GitHub
assert redirect_url.startswith('https://github.com/login/oauth/authorize')
# Parse the redirect URL to get the encrypted state
parsed = urlparse(redirect_url)
query_params = parse_qs(parsed.query)
encrypted_state = query_params['state'][0]
# The redirect_uri should now point to our callback
assert 'github-proxy/callback' in query_params['redirect_uri'][0]
# Now simulate GitHub calling back with this encrypted state
with patch('server.routes.github_proxy.config', mock_config):
callback_response = client.get(
'/github-proxy/callback',
params={
'state': encrypted_state,
'code': 'test-auth-code',
},
follow_redirects=False,
)
assert callback_response.status_code == 307
final_redirect = callback_response.headers['location']
# Verify the callback redirects to the original redirect_uri
assert final_redirect.startswith(original_redirect_uri)
# Parse the final redirect to verify the state was decrypted correctly
final_parsed = urlparse(final_redirect)
final_params = parse_qs(final_parsed.query)
assert final_params['state'][0] == original_state
assert final_params['code'][0] == 'test-auth-code'
File diff suppressed because it is too large Load Diff
+10 -10
View File
@@ -163,7 +163,7 @@ async def test_create_checkout_session_stripe_error(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_billing_enabled'),
patch('server.routes.billing.validate_saas_environment'),
):
await create_checkout_session(
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
@@ -204,7 +204,7 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_billing_enabled'),
patch('server.routes.billing.validate_saas_environment'),
):
mock_db_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_db_session
@@ -236,8 +236,8 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
mode='payment',
payment_method_types=['card'],
saved_payment_method_options={'payment_method_save': 'enabled'},
success_url='https://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
success_url='http://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
cancel_url='http://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
)
# Verify database session creation
@@ -331,7 +331,7 @@ async def test_success_callback_success():
assert response.status_code == 302
assert (
response.headers['location']
== 'https://test.com/settings/billing?checkout=success'
== 'http://test.com/settings/billing?checkout=success'
)
# Verify LiteLLM API calls
@@ -402,7 +402,7 @@ async def test_cancel_callback_session_not_found():
assert response.status_code == 302
assert (
response.headers['location']
== 'https://test.com/settings/billing?checkout=cancel'
== 'http://test.com/settings/billing?checkout=cancel'
)
# Verify no database updates occurred
@@ -429,7 +429,7 @@ async def test_cancel_callback_success():
assert response.status_code == 302
assert (
response.headers['location']
== 'https://test.com/settings/billing?checkout=cancel'
== 'http://test.com/settings/billing?checkout=cancel'
)
# Verify database updates
@@ -490,7 +490,7 @@ async def test_create_customer_setup_session_success():
AsyncMock(return_value=mock_customer_info),
),
patch('stripe.checkout.Session.create_async', mock_create),
patch('server.routes.billing.validate_billing_enabled'),
patch('server.routes.billing.validate_saas_environment'),
):
result = await create_customer_setup_session(mock_request, 'mock_user')
@@ -502,6 +502,6 @@ async def test_create_customer_setup_session_success():
customer='mock-customer-id',
mode='setup',
payment_method_types=['card'],
success_url='https://test.com/?free_credits=success',
cancel_url='https://test.com/',
success_url='http://test.com/?free_credits=success',
cancel_url='http://test.com/',
)
@@ -1126,174 +1126,3 @@ class TestLiteLlmManager:
'http://test.url/team/delete',
json={'team_ids': [team_id]},
)
@pytest.mark.asyncio
async def test_remove_user_from_team_successful(self):
"""
GIVEN: Valid user_id and team_id
WHEN: _remove_user_from_team is called
THEN: HTTP POST is made to remove user from team
"""
mock_response = AsyncMock()
mock_response.is_success = True
mock_response.status_code = 200
with (
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
):
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
await LiteLlmManager._remove_user_from_team(
mock_client, 'test-user-id', 'test-team-id'
)
mock_client.post.assert_called_once_with(
'http://test.url/team/member_delete',
json={
'team_id': 'test-team-id',
'user_id': 'test-user-id',
},
)
@pytest.mark.asyncio
async def test_remove_user_from_team_not_found(self):
"""
GIVEN: User not in team
WHEN: _remove_user_from_team is called
THEN: 404 response is handled gracefully without raising
"""
mock_response = AsyncMock()
mock_response.is_success = False
mock_response.status_code = 404
mock_response.text = 'User not found in team'
mock_response.raise_for_status = MagicMock()
with (
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
):
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Should not raise an exception
await LiteLlmManager._remove_user_from_team(
mock_client, 'test-user-id', 'test-team-id'
)
@pytest.mark.asyncio
async def test_downgrade_entries_missing_config(self, mock_user_settings):
"""Test downgrade_entries when LiteLLM config is missing."""
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
result = await LiteLlmManager.downgrade_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
assert result is None
@pytest.mark.asyncio
async def test_downgrade_entries_team_not_found(self, mock_user_settings):
"""Test downgrade_entries when team is not found."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch.object(
LiteLlmManager, '_get_team', new_callable=AsyncMock
) as mock_get_team:
mock_get_team.return_value = None
result = await LiteLlmManager.downgrade_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
assert result is None
@pytest.mark.asyncio
async def test_downgrade_entries_successful(self, mock_user_settings):
"""Test successful downgrade_entries operation."""
mock_response = MagicMock()
mock_response.is_success = True
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_team_info_response = MagicMock()
mock_team_info_response.is_success = True
mock_team_info_response.status_code = 200
mock_team_info_response.json.return_value = {
'team_info': {
'max_budget': 100.0,
'spend': 20.0,
},
'team_memberships': [
{
'user_id': 'test-user-id',
'team_id': 'test-org-id',
'max_budget_in_team': 100.0,
'spend': 20.0,
}
],
}
mock_team_info_response.raise_for_status = MagicMock()
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'default-team'
):
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = (
mock_client
)
mock_client.get.return_value = mock_team_info_response
mock_client.post.return_value = mock_response
result = await LiteLlmManager.downgrade_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
# downgrade_entries returns the user_settings
assert result is not None
assert result.agent == 'TestAgent'
# Verify downgrade steps were called:
# 1. get_team (GET)
# 2. get_user_team_info (GET via _get_team)
# 3. update_user (POST)
# 4. add_user_to_team (POST)
# 5. update_key (POST)
# 6. remove_user_from_team (POST)
# 7. delete_team (POST)
assert mock_client.get.call_count >= 1
assert mock_client.post.call_count >= 4
@pytest.mark.asyncio
async def test_downgrade_entries_local_deployment(self, mock_user_settings):
"""Test downgrade_entries in local deployment mode (skips LiteLLM calls)."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': 'true'}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
result = await LiteLlmManager.downgrade_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
# In local deployment, should return user_settings without
# making any LiteLLM calls
assert result is not None
assert result.agent == 'TestAgent'
-81
View File
@@ -68,84 +68,3 @@ def test_user_model(session_maker):
)
assert queried_org_member is not None
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
def test_user_model_git_user_fields(session_maker):
"""Test that git_user_name and git_user_email columns exist and work correctly."""
with session_maker() as session:
# Arrange
org = Org(name='test_org_git')
session.add(org)
session.flush()
test_user_id = uuid4()
# Act
user = User(
id=test_user_id,
current_org_id=org.id,
git_user_name='Test Git Author',
git_user_email='git@example.com',
)
session.add(user)
session.commit()
# Assert
queried_user = session.query(User).filter(User.id == test_user_id).first()
assert queried_user.git_user_name == 'Test Git Author'
assert queried_user.git_user_email == 'git@example.com'
def test_user_model_git_user_fields_nullable(session_maker):
"""Test that git_user_name and git_user_email can be null."""
with session_maker() as session:
# Arrange
org = Org(name='test_org_nullable')
session.add(org)
session.flush()
test_user_id = uuid4()
# Act - create user without git fields
user = User(
id=test_user_id,
current_org_id=org.id,
)
session.add(user)
session.commit()
# Assert
queried_user = session.query(User).filter(User.id == test_user_id).first()
assert queried_user.git_user_name is None
assert queried_user.git_user_email is None
def test_user_model_git_user_fields_in_table_columns():
"""Test that git_user_name and git_user_email are in User table columns."""
# Arrange & Act
column_names = [c.name for c in User.__table__.columns]
# Assert
assert 'git_user_name' in column_names
assert 'git_user_email' in column_names
def test_user_model_git_user_fields_hasattr(session_maker):
"""Test that hasattr returns True for git_user_* fields on User model.
This verifies the fix for SaasSettingsStore.store() which uses hasattr
to determine if a field should be persisted to a model.
"""
with session_maker() as session:
# Arrange
org = Org(name='test_org_hasattr')
session.add(org)
session.flush()
user = User(id=uuid4(), current_org_id=org.id)
session.add(user)
session.flush()
# Assert - hasattr must return True for store() to work
assert hasattr(user, 'git_user_name')
assert hasattr(user, 'git_user_email')
File diff suppressed because it is too large Load Diff
-371
View File
@@ -415,374 +415,3 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm
)
assert persisted_member.max_iterations == 100
assert persisted_member.llm_model == 'gpt-4'
@pytest.mark.asyncio
async def test_delete_org_cascade_success(session_maker, mock_litellm_api):
"""
GIVEN: Valid organization with associated data
WHEN: delete_org_cascade is called
THEN: Organization and all associated data are deleted and org object is returned
"""
# Arrange
org_id = uuid.uuid4()
# Create expected return object
expected_org = Org(
id=org_id,
name='Test Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
# Mock delete_org_cascade to avoid database schema constraints
async def mock_delete_org_cascade(org_id_param):
# Verify the method was called with correct parameter
assert org_id_param == org_id
# Return the organization object (simulating successful deletion)
return expected_org
with patch(
'storage.org_store.OrgStore.delete_org_cascade', mock_delete_org_cascade
):
# Act
result = await OrgStore.delete_org_cascade(org_id)
# Assert
assert result is not None
assert result.id == org_id
assert result.name == 'Test Organization'
assert result.contact_name == 'John Doe'
assert result.contact_email == 'john@example.com'
@pytest.mark.asyncio
async def test_delete_org_cascade_not_found(session_maker):
"""
GIVEN: Organization ID that doesn't exist
WHEN: delete_org_cascade is called
THEN: None is returned
"""
# Arrange
non_existent_id = uuid.uuid4()
with patch('storage.org_store.session_maker', session_maker):
# Act
result = await OrgStore.delete_org_cascade(non_existent_id)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_delete_org_cascade_litellm_failure_causes_rollback(
session_maker, mock_litellm_api
):
"""
GIVEN: Organization exists but LiteLLM cleanup fails
WHEN: delete_org_cascade is called
THEN: Transaction is rolled back and organization still exists
"""
# Arrange
org_id = uuid.uuid4()
user_id = uuid.uuid4()
with session_maker() as session:
role = Role(id=1, name='owner', rank=1)
user = User(id=user_id, current_org_id=org_id)
org = Org(
id=org_id,
name='Test Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
org_member = OrgMember(
org_id=org_id,
user_id=user_id,
role_id=1,
status='active',
llm_api_key='test-key',
)
session.add_all([role, user, org, org_member])
session.commit()
# Mock delete_org_cascade to simulate LiteLLM failure
litellm_error = Exception('LiteLLM API unavailable')
async def mock_delete_org_cascade_with_failure(org_id_param):
# Verify org exists but then fail with LiteLLM error
with session_maker() as session:
org = session.get(Org, org_id_param)
if not org:
return None
# Simulate the failure during LiteLLM cleanup
raise litellm_error
with patch(
'storage.org_store.OrgStore.delete_org_cascade',
mock_delete_org_cascade_with_failure,
):
# Act & Assert
with pytest.raises(Exception) as exc_info:
await OrgStore.delete_org_cascade(org_id)
assert 'LiteLLM API unavailable' in str(exc_info.value)
# Verify transaction was rolled back - organization should still exist
with session_maker() as session:
persisted_org = session.get(Org, org_id)
assert persisted_org is not None
assert persisted_org.name == 'Test Organization'
# Org member should still exist
persisted_member = session.query(OrgMember).filter_by(org_id=org_id).first()
assert persisted_member is not None
def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
"""
GIVEN: User is member of multiple organizations
WHEN: get_user_orgs_paginated is called without page_id
THEN: First page of organizations is returned in alphabetical order
"""
# Arrange
user_id = uuid.uuid4()
other_user_id = uuid.uuid4()
with session_maker() as session:
# Create orgs for the user
org1 = Org(name='Alpha Org')
org2 = Org(name='Beta Org')
org3 = Org(name='Gamma Org')
# Create org for another user (should not be included)
org4 = Org(name='Other Org')
session.add_all([org1, org2, org3, org4])
session.flush()
# Create user and role
user = User(id=user_id, current_org_id=org1.id)
other_user = User(id=other_user_id, current_org_id=org4.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, other_user, role])
session.flush()
# Create memberships
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
member2 = OrgMember(
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
)
member3 = OrgMember(
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
)
other_member = OrgMember(
org_id=org4.id, user_id=other_user_id, role_id=1, llm_api_key='key4'
)
session.add_all([member1, member2, member3, other_member])
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id=None, limit=2
)
# Assert
assert len(orgs) == 2
assert orgs[0].name == 'Alpha Org'
assert orgs[1].name == 'Beta Org'
assert next_page_id == '2' # Has more results
# Verify other user's org is not included
org_names = [org.name for org in orgs]
assert 'Other Org' not in org_names
def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api):
"""
GIVEN: User has multiple organizations and page_id is provided
WHEN: get_user_orgs_paginated is called with page_id
THEN: Organizations starting from offset are returned
"""
# Arrange
user_id = uuid.uuid4()
with session_maker() as session:
org1 = Org(name='Alpha Org')
org2 = Org(name='Beta Org')
org3 = Org(name='Gamma Org')
session.add_all([org1, org2, org3])
session.flush()
user = User(id=user_id, current_org_id=org1.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, role])
session.flush()
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
member2 = OrgMember(
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
)
member3 = OrgMember(
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
)
session.add_all([member1, member2, member3])
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id='1', limit=1
)
# Assert
assert len(orgs) == 1
assert orgs[0].name == 'Beta Org' # Second org (offset 1)
assert next_page_id == '2' # Has more results
def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api):
"""
GIVEN: User has organizations but fewer than limit
WHEN: get_user_orgs_paginated is called
THEN: All organizations are returned and next_page_id is None
"""
# Arrange
user_id = uuid.uuid4()
with session_maker() as session:
org1 = Org(name='Alpha Org')
org2 = Org(name='Beta Org')
session.add_all([org1, org2])
session.flush()
user = User(id=user_id, current_org_id=org1.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, role])
session.flush()
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
member2 = OrgMember(
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
)
session.add_all([member1, member2])
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id=None, limit=10
)
# Assert
assert len(orgs) == 2
assert next_page_id is None
def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api):
"""
GIVEN: Invalid page_id (non-numeric string)
WHEN: get_user_orgs_paginated is called
THEN: Results start from beginning (offset 0)
"""
# Arrange
user_id = uuid.uuid4()
with session_maker() as session:
org1 = Org(name='Alpha Org')
session.add(org1)
session.flush()
user = User(id=user_id, current_org_id=org1.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, role])
session.flush()
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
session.add(member1)
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id='invalid', limit=10
)
# Assert
assert len(orgs) == 1
assert orgs[0].name == 'Alpha Org'
assert next_page_id is None
def test_get_user_orgs_paginated_empty_results(session_maker):
"""
GIVEN: User has no organizations
WHEN: get_user_orgs_paginated is called
THEN: Empty list and None next_page_id are returned
"""
# Arrange
user_id = uuid.uuid4()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id=None, limit=10
)
# Assert
assert len(orgs) == 0
assert next_page_id is None
def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
"""
GIVEN: User has organizations with different names
WHEN: get_user_orgs_paginated is called
THEN: Organizations are returned in alphabetical order by name
"""
# Arrange
user_id = uuid.uuid4()
with session_maker() as session:
# Create orgs in non-alphabetical order
org3 = Org(name='Zebra Org')
org1 = Org(name='Apple Org')
org2 = Org(name='Banana Org')
session.add_all([org3, org1, org2])
session.flush()
user = User(id=user_id, current_org_id=org1.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, role])
session.flush()
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
member2 = OrgMember(
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
)
member3 = OrgMember(
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
)
session.add_all([member1, member2, member3])
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, _ = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id=None, limit=10
)
# Assert
assert len(orgs) == 3
assert orgs[0].name == 'Apple Org'
assert orgs[1].name == 'Banana Org'
assert orgs[2].name == 'Zebra Org'
+4 -91
View File
@@ -1,36 +1,9 @@
from unittest.mock import patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.role import Role
from storage.role_store import RoleStore
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker for testing."""
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
# Mock the database module before importing RoleStore
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from storage.role import Role
from storage.role_store import RoleStore
def test_get_role_by_id(session_maker):
@@ -108,63 +81,3 @@ def test_create_role(session_maker):
assert role.name == 'moderator'
assert role.rank == 2
assert role.id is not None
@pytest.mark.asyncio
async def test_get_role_by_name_async_with_session(async_session_maker):
"""Test getting role by name asynchronously with an explicit session."""
# Create a test role
async with async_session_maker() as session:
role = Role(name='admin', rank=1)
session.add(role)
await session.commit()
await session.refresh(role)
role_id = role.id
# Test retrieval with explicit session
async with async_session_maker() as session:
retrieved_role = await RoleStore.get_role_by_name_async(
'admin', session=session
)
assert retrieved_role is not None
assert retrieved_role.id == role_id
assert retrieved_role.name == 'admin'
assert retrieved_role.rank == 1
@pytest.mark.asyncio
async def test_get_role_by_name_async_without_session(async_session_maker):
"""Test getting role by name asynchronously using internal session maker."""
# Create a test role
async with async_session_maker() as session:
role = Role(name='editor', rank=2)
session.add(role)
await session.commit()
await session.refresh(role)
role_id = role.id
# Test retrieval without explicit session (using patched a_session_maker)
with patch('storage.role_store.a_session_maker', async_session_maker):
retrieved_role = await RoleStore.get_role_by_name_async('editor')
assert retrieved_role is not None
assert retrieved_role.id == role_id
assert retrieved_role.name == 'editor'
assert retrieved_role.rank == 2
@pytest.mark.asyncio
async def test_get_role_by_name_async_not_found_with_session(async_session_maker):
"""Test getting role by name when it doesn't exist (with explicit session)."""
async with async_session_maker() as session:
retrieved_role = await RoleStore.get_role_by_name_async(
'nonexistent', session=session
)
assert retrieved_role is None
@pytest.mark.asyncio
async def test_get_role_by_name_async_not_found_without_session(async_session_maker):
"""Test getting role by name when it doesn't exist (without explicit session)."""
with patch('storage.role_store.a_session_maker', async_session_maker):
retrieved_role = await RoleStore.get_role_by_name_async('nonexistent')
assert retrieved_role is None
@@ -2,7 +2,7 @@
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import UUID, uuid4
from uuid import uuid4
import pytest
from server.sharing.shared_conversation_models import (
@@ -13,9 +13,6 @@ from server.sharing.sql_shared_conversation_info_service import (
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.org import Org
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.user import User
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
@@ -431,261 +428,3 @@ class TestSharedConversationInfoService:
page1_ids = {item.id for item in result.items}
page2_ids = {item.id for item in result2.items}
assert page1_ids.isdisjoint(page2_ids)
class TestSharedConversationInfoServiceWithSaasMetadata:
"""Test cases for SharedConversationInfoService with SAAS metadata.
These tests verify that created_by_user_id is correctly retrieved from
the conversation_metadata_saas table when it exists.
"""
@pytest.fixture
async def async_engine_with_saas(self):
"""Create an async SQLite engine with all SAAS tables."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_with_saas(
self, async_engine_with_saas
) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session for testing with SAAS tables."""
async_session_maker = async_sessionmaker(
async_engine_with_saas, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
yield db_session
@pytest.fixture
async def test_org(self, async_session_with_saas) -> Org:
"""Create a test organization."""
org = Org(id=uuid4(), name=f'test_org_{uuid4().hex[:8]}')
async_session_with_saas.add(org)
await async_session_with_saas.commit()
return org
@pytest.fixture
async def test_user(self, async_session_with_saas, test_org) -> User:
"""Create a test user belonging to the test organization."""
user = User(id=uuid4(), current_org_id=test_org.id)
async_session_with_saas.add(user)
await async_session_with_saas.commit()
return user
@pytest.fixture
async def shared_service_with_saas(self, async_session_with_saas):
"""Create a SharedConversationInfoService for testing."""
return SQLSharedConversationInfoService(db_session=async_session_with_saas)
@pytest.fixture
async def app_service_with_saas(self, async_session_with_saas):
"""Create an AppConversationInfoService for creating test data."""
return SQLAppConversationInfoService(
db_session=async_session_with_saas,
user_context=SpecifyUserContext(user_id=None),
)
async def _create_saas_metadata(
self,
db_session: AsyncSession,
conversation_id: UUID,
user_id: UUID,
org_id: UUID,
) -> StoredConversationMetadataSaas:
"""Helper to create SAAS metadata for a conversation."""
saas_metadata = StoredConversationMetadataSaas(
conversation_id=str(conversation_id),
user_id=user_id,
org_id=org_id,
)
db_session.add(saas_metadata)
await db_session.commit()
return saas_metadata
@pytest.mark.asyncio
async def test_get_shared_conversation_returns_user_id_from_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test that get_shared_conversation_info returns created_by_user_id from SAAS metadata."""
# Arrange
conversation_id = uuid4()
conversation = AppConversationInfo(
id=conversation_id,
created_by_user_id=None,
sandbox_id='test_sandbox',
title='Public Conversation With User',
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conversation)
await self._create_saas_metadata(
async_session_with_saas, conversation_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.get_shared_conversation_info(
conversation_id
)
# Assert
assert result is not None
assert result.created_by_user_id == str(test_user.id)
@pytest.mark.asyncio
async def test_search_shared_conversations_returns_user_id_from_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test that search_shared_conversation_info returns created_by_user_id from SAAS metadata."""
# Arrange
conversation_id = uuid4()
conversation = AppConversationInfo(
id=conversation_id,
created_by_user_id=None,
sandbox_id='test_sandbox_search',
title='Searchable Public Conversation',
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conversation)
await self._create_saas_metadata(
async_session_with_saas, conversation_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.search_shared_conversation_info()
# Assert
assert len(result.items) == 1
assert result.items[0].created_by_user_id == str(test_user.id)
@pytest.mark.asyncio
async def test_batch_get_shared_conversations_returns_user_id_from_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test that batch_get_shared_conversation_info returns created_by_user_id from SAAS metadata."""
# Arrange
conversation_id = uuid4()
conversation = AppConversationInfo(
id=conversation_id,
created_by_user_id=None,
sandbox_id='test_sandbox_batch',
title='Batch Get Conversation',
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conversation)
await self._create_saas_metadata(
async_session_with_saas, conversation_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.batch_get_shared_conversation_info(
[conversation_id]
)
# Assert
assert len(result) == 1
assert result[0] is not None
assert result[0].created_by_user_id == str(test_user.id)
@pytest.mark.asyncio
async def test_mixed_conversations_with_and_without_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test handling of conversations where some have SAAS metadata and some don't."""
# Arrange
conv_with_saas_id = uuid4()
conv_without_saas_id = uuid4()
conv_with_saas = AppConversationInfo(
id=conv_with_saas_id,
created_by_user_id=None,
sandbox_id='sandbox_with_saas',
title='With SAAS Metadata',
created_at=datetime(2023, 1, 2, tzinfo=UTC),
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conv_without_saas = AppConversationInfo(
id=conv_without_saas_id,
created_by_user_id=None,
sandbox_id='sandbox_without_saas',
title='Without SAAS Metadata',
created_at=datetime(2023, 1, 1, tzinfo=UTC),
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conv_with_saas)
await app_service_with_saas.save_app_conversation_info(conv_without_saas)
await self._create_saas_metadata(
async_session_with_saas, conv_with_saas_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT
)
# Assert
assert len(result.items) == 2
conv_without = next(
item for item in result.items if item.id == conv_without_saas_id
)
conv_with = next(item for item in result.items if item.id == conv_with_saas_id)
assert conv_without.created_by_user_id is None
assert conv_with.created_by_user_id == str(test_user.id)
-12
View File
@@ -149,18 +149,6 @@ def test_infer_repo_from_message():
('https://github.com/My-User/My-Repo.git', ['My-User/My-Repo']),
('Check the my.user/my.repo repository', ['my.user/my.repo']),
('repos: user_1/repo-1 and user.2/repo_2', ['user_1/repo-1', 'user.2/repo_2']),
# Backtick-wrapped repo mentions (common in Slack/Discord messages)
(
'@openhands-exp just echo hello world in `OpenHands/OpenHands-CLI` repository',
['OpenHands/OpenHands-CLI'],
),
(
'@openhands-exp echo hello world with {{OpenHands/OpenHands-CLI}}',
['OpenHands/OpenHands-CLI'],
),
('Deploy the `test/project` repo', ['test/project']),
# Colon-wrapped repo mentions
('Check the :owner/repo: here', ['owner/repo']),
# Large number of repositories
('Repos: a/b, c/d, e/f, g/h, i/j', ['a/b', 'c/d', 'e/f', 'g/h', 'i/j']),
# Mixed with false positives that should be filtered
@@ -1,159 +0,0 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import { screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { renderWithProviders } from "test-utils";
import { PlanPreview } from "#/components/features/chat/plan-preview";
// Mock the feature flag to always return true (not testing feature flag behavior)
vi.mock("#/utils/feature-flags", () => ({
USE_PLANNING_AGENT: vi.fn(() => true),
}));
// Mock i18n - need to preserve initReactI18next and I18nextProvider for test-utils
vi.mock("react-i18next", async (importOriginal) => {
const actual = await importOriginal<typeof import("react-i18next")>();
return {
...actual,
useTranslation: () => ({
t: (key: string) => key,
}),
};
});
describe("PlanPreview", () => {
beforeEach(() => {
vi.clearAllMocks();
});
afterEach(() => {
vi.clearAllMocks();
});
it("should render nothing when planContent is null", () => {
renderWithProviders(<PlanPreview planContent={null} />);
const contentDiv = screen.getByTestId("plan-preview-content");
expect(contentDiv).toBeInTheDocument();
expect(contentDiv.textContent?.trim() || "").toBe("");
});
it("should render nothing when planContent is undefined", () => {
renderWithProviders(<PlanPreview planContent={undefined} />);
const contentDiv = screen.getByTestId("plan-preview-content");
expect(contentDiv).toBeInTheDocument();
expect(contentDiv.textContent?.trim() || "").toBe("");
});
it("should render markdown content when planContent is provided", () => {
const planContent = "# Plan Title\n\nThis is the plan content.";
const { container } = renderWithProviders(
<PlanPreview planContent={planContent} />,
);
// Check that component rendered and contains the content (markdown may break up text)
expect(container.firstChild).not.toBeNull();
expect(container.textContent).toContain("Plan Title");
expect(container.textContent).toContain("This is the plan content.");
});
it("should render full content when length is less than or equal to 300 characters", () => {
const planContent = "A".repeat(300);
const { container } = renderWithProviders(
<PlanPreview planContent={planContent} />,
);
// Content should be present (may be broken up by markdown)
expect(container.textContent).toContain(planContent);
expect(screen.queryByText(/COMMON\$READ_MORE/i)).not.toBeInTheDocument();
});
it("should truncate content when length exceeds 300 characters", () => {
const longContent = "A".repeat(350);
const { container } = renderWithProviders(
<PlanPreview planContent={longContent} />,
);
// Truncated content should be present (may be broken up by markdown)
expect(container.textContent).toContain("A".repeat(300));
expect(container.textContent).toContain("...");
expect(container.textContent).toContain("COMMON$READ_MORE");
});
it("should call onViewClick when View button is clicked", async () => {
const user = userEvent.setup();
const onViewClick = vi.fn();
renderWithProviders(
<PlanPreview planContent="Plan content" onViewClick={onViewClick} />,
);
const viewButton = screen.getByTestId("plan-preview-view-button");
expect(viewButton).toBeInTheDocument();
await user.click(viewButton);
expect(onViewClick).toHaveBeenCalledTimes(1);
});
it("should call onViewClick when Read More button is clicked", async () => {
const user = userEvent.setup();
const onViewClick = vi.fn();
const longContent = "A".repeat(350);
renderWithProviders(
<PlanPreview planContent={longContent} onViewClick={onViewClick} />,
);
const readMoreButton = screen.getByTestId("plan-preview-read-more-button");
expect(readMoreButton).toBeInTheDocument();
await user.click(readMoreButton);
expect(onViewClick).toHaveBeenCalledTimes(1);
});
it("should call onBuildClick when Build button is clicked", async () => {
const user = userEvent.setup();
const onBuildClick = vi.fn();
renderWithProviders(
<PlanPreview planContent="Plan content" onBuildClick={onBuildClick} />,
);
const buildButton = screen.getByTestId("plan-preview-build-button");
expect(buildButton).toBeInTheDocument();
await user.click(buildButton);
expect(onBuildClick).toHaveBeenCalledTimes(1);
});
it("should render header with PLAN_MD text", () => {
const { container } = renderWithProviders(
<PlanPreview planContent="Plan content" />,
);
// Check that the translation key is rendered (i18n mock returns the key)
expect(container.textContent).toContain("COMMON$PLAN_MD");
});
it("should render plan content", () => {
const planContent = `# Heading 1
## Heading 2
- List item 1
- List item 2
**Bold text** and *italic text*`;
const { container } = renderWithProviders(
<PlanPreview planContent={planContent} />,
);
expect(container.textContent).toContain("Heading 1");
expect(container.textContent).toContain("Heading 2");
});
});
@@ -1,186 +0,0 @@
import { render, screen, waitFor } from "@testing-library/react";
import { describe, expect, vi, beforeEach, it } from "vitest";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import userEvent from "@testing-library/user-event";
import { GitBranchDropdown } from "../../../../src/components/features/home/git-branch-dropdown/git-branch-dropdown";
import { Branch } from "#/types/git";
// Mock the branch data hook
const mockUseBranchData = vi.fn();
vi.mock("#/hooks/query/use-branch-data", () => ({
useBranchData: (...args: unknown[]) => mockUseBranchData(...args),
}));
const MOCK_BRANCHES: Branch[] = [
{ name: "main", commit_sha: "abc123", protected: true },
{ name: "develop", commit_sha: "def456", protected: false },
{ name: "feature/test", commit_sha: "ghi789", protected: false },
];
const mockOnBranchSelect = vi.fn();
const renderDropdown = (
props: Partial<Parameters<typeof GitBranchDropdown>[0]> = {},
) => {
// Default mock return value
mockUseBranchData.mockReturnValue({
branches: MOCK_BRANCHES,
isLoading: false,
isError: false,
fetchNextPage: vi.fn(),
hasNextPage: false,
isFetchingNextPage: false,
isSearchLoading: false,
});
return render(
<GitBranchDropdown
repository="user/repo"
provider="github"
selectedBranch={null}
onBranchSelect={mockOnBranchSelect}
// eslint-disable-next-line react/jsx-props-no-spreading
{...props}
/>,
{
wrapper: ({ children }) => (
<QueryClientProvider
client={
new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
}
>
{children}
</QueryClientProvider>
),
},
);
};
describe("GitBranchDropdown", () => {
beforeEach(() => {
vi.clearAllMocks();
});
describe("dropdown behavior", () => {
it("should open dropdown when input is clicked", async () => {
renderDropdown();
const input = screen.getByTestId("git-branch-dropdown-input");
await userEvent.click(input);
// Dropdown should be open (menu should be visible)
await waitFor(() => {
expect(
screen.getByTestId("git-branch-dropdown-menu"),
).toBeInTheDocument();
});
});
it("should keep dropdown open when clicking input while already open", async () => {
renderDropdown();
const input = screen.getByTestId("git-branch-dropdown-input");
// First click - open dropdown
await userEvent.click(input);
await waitFor(() => {
expect(
screen.getByTestId("git-branch-dropdown-menu"),
).toBeInTheDocument();
});
// Second click on input - should stay open (not close)
await userEvent.click(input);
// Dropdown should still be open
await waitFor(() => {
expect(
screen.getByTestId("git-branch-dropdown-menu"),
).toBeInTheDocument();
});
});
it("should preserve typed text when clicking input while typing", async () => {
renderDropdown();
const input = screen.getByTestId(
"git-branch-dropdown-input",
) as HTMLInputElement;
// Click to open and type
await userEvent.click(input);
await userEvent.type(input, "feat");
expect(input.value).toBe("feat");
// Click on input again (should not reset text)
await userEvent.click(input);
// Text should be preserved
expect(input.value).toBe("feat");
});
});
describe("cursor position preservation", () => {
it("should allow editing in the middle of input text", async () => {
renderDropdown();
const input = screen.getByTestId(
"git-branch-dropdown-input",
) as HTMLInputElement;
// Click and type initial text
await userEvent.click(input);
await userEvent.type(input, "hello");
expect(input.value).toBe("hello");
// Move cursor to position 2 and type
input.setSelectionRange(2, 2);
await userEvent.type(input, "X");
// The character should be inserted (exact position may vary based on browser behavior)
expect(input.value).toContain("X");
});
});
describe("input synchronization", () => {
it("should show selected branch name in input when provided", async () => {
const selectedBranch = MOCK_BRANCHES[0];
renderDropdown({ selectedBranch });
const input = screen.getByTestId(
"git-branch-dropdown-input",
) as HTMLInputElement;
await waitFor(() => {
expect(input.value).toBe(selectedBranch.name);
});
});
});
describe("branch selection", () => {
it("should call onBranchSelect when a branch is selected", async () => {
renderDropdown();
const input = screen.getByTestId("git-branch-dropdown-input");
await userEvent.click(input);
// Wait for dropdown to open and show branches
await waitFor(() => {
expect(screen.getByText("main")).toBeInTheDocument();
});
// Click on a branch
await userEvent.click(screen.getByText("develop"));
expect(mockOnBranchSelect).toHaveBeenCalledWith(MOCK_BRANCHES[1]);
});
});
});
@@ -1,234 +0,0 @@
import { render, screen, waitFor } from "@testing-library/react";
import { describe, expect, vi, beforeEach, it } from "vitest";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import userEvent from "@testing-library/user-event";
import { GitRepoDropdown } from "../../../../src/components/features/home/git-repo-dropdown/git-repo-dropdown";
import { GitRepository } from "#/types/git";
// Mock the repository data hook
const mockUseRepositoryData = vi.fn();
vi.mock(
"#/components/features/home/git-repo-dropdown/use-repository-data",
() => ({
useRepositoryData: (...args: unknown[]) => mockUseRepositoryData(...args),
}),
);
// Mock the URL search hook
const mockUseUrlSearch = vi.fn();
vi.mock("#/components/features/home/git-repo-dropdown/use-url-search", () => ({
useUrlSearch: (...args: unknown[]) => mockUseUrlSearch(...args),
}));
// Mock useConfig
vi.mock("#/hooks/query/use-config", () => ({
useConfig: () => ({ data: null }),
}));
// Mock useHomeStore
vi.mock("#/stores/home-store", () => ({
useHomeStore: () => ({ recentRepositories: [] }),
}));
const MOCK_REPOSITORIES: GitRepository[] = [
{
id: "1",
full_name: "user/repo-one",
git_provider: "github",
is_public: true,
},
{
id: "2",
full_name: "user/repo-two",
git_provider: "github",
is_public: true,
},
{
id: "3",
full_name: "org/feature-repo",
git_provider: "github",
is_public: false,
},
];
const mockOnChange = vi.fn();
const setupDefaultMocks = (
repositoryDataOverrides: Partial<
ReturnType<typeof mockUseRepositoryData>
> = {},
) => {
mockUseRepositoryData.mockReturnValue({
repositories: MOCK_REPOSITORIES,
selectedRepository: null,
isLoading: false,
isError: false,
fetchNextPage: vi.fn(),
hasNextPage: false,
isFetchingNextPage: false,
isSearchLoading: false,
...repositoryDataOverrides,
});
mockUseUrlSearch.mockReturnValue({
urlSearchResults: [],
isUrlSearchLoading: false,
});
};
const renderDropdown = (
props: Partial<Parameters<typeof GitRepoDropdown>[0]> = {},
repositoryDataOverrides: Partial<
ReturnType<typeof mockUseRepositoryData>
> = {},
) => {
// Set up mocks with optional overrides
setupDefaultMocks(repositoryDataOverrides);
return render(
<GitRepoDropdown
provider="github"
onChange={mockOnChange}
// eslint-disable-next-line react/jsx-props-no-spreading
{...props}
/>,
{
wrapper: ({ children }) => (
<QueryClientProvider
client={
new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
}
>
{children}
</QueryClientProvider>
),
},
);
};
describe("GitRepoDropdown", () => {
beforeEach(() => {
vi.clearAllMocks();
});
describe("dropdown behavior", () => {
it("should open dropdown when input is clicked", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown");
await userEvent.click(input);
// Dropdown should be open (menu should be visible)
await waitFor(() => {
expect(
screen.getByTestId("git-repo-dropdown-menu"),
).toBeInTheDocument();
});
});
it("should keep dropdown open when clicking input while already open", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown");
// First click - open dropdown
await userEvent.click(input);
await waitFor(() => {
expect(
screen.getByTestId("git-repo-dropdown-menu"),
).toBeInTheDocument();
});
// Second click on input - should stay open (not close)
await userEvent.click(input);
// Dropdown should still be open
await waitFor(() => {
expect(
screen.getByTestId("git-repo-dropdown-menu"),
).toBeInTheDocument();
});
});
it("should preserve typed text when clicking input while typing", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
// Click to open and type
await userEvent.click(input);
await userEvent.type(input, "repo");
expect(input.value).toBe("repo");
// Click on input again (should not reset text)
await userEvent.click(input);
// Text should be preserved
expect(input.value).toBe("repo");
});
});
describe("cursor position preservation", () => {
it("should allow editing in the middle of input text", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
// Click and type initial text
await userEvent.click(input);
await userEvent.type(input, "hello");
expect(input.value).toBe("hello");
// Move cursor to position 2 and type
input.setSelectionRange(2, 2);
await userEvent.type(input, "X");
// The character should be inserted (exact position may vary based on browser behavior)
expect(input.value).toContain("X");
});
});
describe("input synchronization", () => {
it("should show selected repository name in input when provided", async () => {
const selectedRepository = MOCK_REPOSITORIES[0];
renderDropdown(
{ value: selectedRepository.full_name },
{ selectedRepository },
);
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
await waitFor(() => {
expect(input.value).toBe(selectedRepository.full_name);
});
});
});
describe("repository selection", () => {
it("should call onChange when a repository is selected", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown");
await userEvent.click(input);
// Wait for dropdown to open and show repositories
await waitFor(() => {
expect(screen.getByText("user/repo-one")).toBeInTheDocument();
});
// Click on a repository
await userEvent.click(screen.getByText("user/repo-two"));
expect(mockOnChange).toHaveBeenCalledWith(MOCK_REPOSITORIES[1]);
});
});
});
@@ -1,35 +0,0 @@
import { describe, expect, it } from "vitest";
import { shouldRenderEvent } from "#/components/v1/chat/event-content-helpers/should-render-event";
import {
createPlanningFileEditorActionEvent,
createOtherActionEvent,
createPlanningObservationEvent,
createUserMessageEvent,
} from "test-utils";
describe("shouldRenderEvent - PlanningFileEditorAction", () => {
it("should return false for PlanningFileEditorAction", () => {
const event = createPlanningFileEditorActionEvent("action-1");
expect(shouldRenderEvent(event)).toBe(false);
});
it("should return true for other action types", () => {
const event = createOtherActionEvent("action-1");
expect(shouldRenderEvent(event)).toBe(true);
});
it("should return true for PlanningFileEditorObservation", () => {
const event = createPlanningObservationEvent("obs-1");
// Observations should still render (they're handled separately in event-message)
expect(shouldRenderEvent(event)).toBe(true);
});
it("should return true for user message events", () => {
const event = createUserMessageEvent("msg-1");
expect(shouldRenderEvent(event)).toBe(true);
});
});
@@ -1,159 +0,0 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import { screen, render } from "@testing-library/react";
import { EventMessage } from "#/components/v1/chat/event-message";
import { useConversationStore } from "#/stores/conversation-store";
import {
renderWithProviders,
createPlanningObservationEvent,
} from "test-utils";
// Mock the feature flag
vi.mock("#/utils/feature-flags", () => ({
USE_PLANNING_AGENT: vi.fn(() => true),
}));
// Mock useConfig
vi.mock("#/hooks/query/use-config", () => ({
useConfig: () => ({
data: { APP_MODE: "saas" },
}),
}));
// Mock PlanPreview component to verify it's rendered
vi.mock("#/components/features/chat/plan-preview", () => ({
PlanPreview: ({ planContent }: { planContent?: string | null }) => (
<div data-testid="plan-preview">Plan Preview: {planContent || "null"}</div>
),
}));
describe("EventMessage - PlanPreview rendering", () => {
beforeEach(() => {
vi.clearAllMocks();
// Reset conversation store
useConversationStore.setState({
planContent: null,
});
});
afterEach(() => {
vi.clearAllMocks();
});
it("should render PlanPreview when PlanningFileEditorObservation event ID is in planPreviewEventIds", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set(["plan-obs-1"]);
const planContent = "This is the plan content";
useConversationStore.setState({ planContent });
renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(screen.getByTestId("plan-preview")).toBeInTheDocument();
expect(
screen.getByText(`Plan Preview: ${planContent}`),
).toBeInTheDocument();
});
it("should return null when PlanningFileEditorObservation event ID is NOT in planPreviewEventIds", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set(["plan-obs-2"]); // Different ID
const { container } = renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
expect(container.firstChild).toBeNull();
});
it("should return null when planPreviewEventIds is undefined", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const { container } = renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={undefined}
/>,
);
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
expect(container.firstChild).toBeNull();
});
it("should use planContent from conversation store", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set(["plan-obs-1"]);
const planContent = "Store plan content";
useConversationStore.setState({ planContent });
renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(
screen.getByText(`Plan Preview: ${planContent}`),
).toBeInTheDocument();
});
it("should handle null planContent from store", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set(["plan-obs-1"]);
useConversationStore.setState({ planContent: null });
renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(screen.getByTestId("plan-preview")).toBeInTheDocument();
expect(screen.getByText("Plan Preview: null")).toBeInTheDocument();
});
it("should handle empty planPreviewEventIds set", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set<string>();
const { container } = renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
expect(container.firstChild).toBeNull();
});
});
@@ -1,195 +0,0 @@
import { renderHook } from "@testing-library/react";
import { describe, expect, it } from "vitest";
import {
usePlanPreviewEvents,
shouldShowPlanPreview,
} from "#/components/v1/chat/hooks/use-plan-preview-events";
import {
OpenHandsEvent,
MessageEvent,
ObservationEvent,
PlanningFileEditorObservation,
} from "#/types/v1/core";
// Helper to create a user message event
const createUserMessageEvent = (id: string): MessageEvent => ({
id,
timestamp: new Date().toISOString(),
source: "user",
llm_message: {
role: "user",
content: [{ type: "text", text: "User message" }],
},
activated_microagents: [],
extended_content: [],
});
// Helper to create a PlanningFileEditorObservation event
const createPlanningObservationEvent = (
id: string,
actionId: string = "action-1",
): ObservationEvent<PlanningFileEditorObservation> => ({
id,
timestamp: new Date().toISOString(),
source: "environment",
tool_name: "planning_file_editor",
tool_call_id: "call-1",
action_id: actionId,
observation: {
kind: "PlanningFileEditorObservation",
content: [{ type: "text", text: "Plan content" }],
is_error: false,
command: "create",
path: "/workspace/PLAN.md",
prev_exist: false,
old_content: null,
new_content: "Plan content",
},
});
// Helper to create a non-planning observation event
const createOtherObservationEvent = (id: string): ObservationEvent => ({
id,
timestamp: new Date().toISOString(),
source: "environment",
tool_name: "execute_bash",
tool_call_id: "call-1",
action_id: "action-1",
observation: {
kind: "ExecuteBashObservation",
content: [{ type: "text", text: "output" }],
command: "echo test",
exit_code: 0,
error: false,
timeout: false,
metadata: {
exit_code: 0,
pid: 12345,
username: "user",
hostname: "localhost",
working_dir: "/home/user",
py_interpreter_path: null,
prefix: "",
suffix: "",
},
},
});
describe("usePlanPreviewEvents", () => {
it("should return empty set when no events provided", () => {
const { result } = renderHook(() => usePlanPreviewEvents([]));
expect(result.current).toBeInstanceOf(Set);
expect(result.current.size).toBe(0);
});
it("should return empty set when no PlanningFileEditorObservation events exist", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createOtherObservationEvent("obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
expect(result.current.size).toBe(0);
});
it("should return event ID for single PlanningFileEditorObservation in one phase", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createPlanningObservationEvent("plan-obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
expect(result.current.size).toBe(1);
expect(result.current.has("plan-obs-1")).toBe(true);
});
it("should return only the last PlanningFileEditorObservation when multiple exist in one phase", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createPlanningObservationEvent("plan-obs-1"),
createPlanningObservationEvent("plan-obs-2"),
createPlanningObservationEvent("plan-obs-3"),
createOtherObservationEvent("other-obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
// Should only include the last one in the phase
expect(result.current.size).toBe(1);
expect(result.current.has("plan-obs-1")).toBe(false);
expect(result.current.has("plan-obs-2")).toBe(false);
expect(result.current.has("plan-obs-3")).toBe(true);
});
it("should return one event ID per phase when multiple phases exist", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createPlanningObservationEvent("plan-obs-1"),
createPlanningObservationEvent("plan-obs-2"),
createUserMessageEvent("user-2"),
createPlanningObservationEvent("plan-obs-3"),
createPlanningObservationEvent("plan-obs-4"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
// Should have one preview per phase (last observation in each phase)
expect(result.current.size).toBe(2);
expect(result.current.has("plan-obs-2")).toBe(true); // Last in phase 1
expect(result.current.has("plan-obs-4")).toBe(true); // Last in phase 2
expect(result.current.has("plan-obs-1")).toBe(false);
expect(result.current.has("plan-obs-3")).toBe(false);
});
it("should handle phase with no PlanningFileEditorObservation", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createOtherObservationEvent("obs-1"),
createUserMessageEvent("user-2"),
createPlanningObservationEvent("plan-obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
// Only phase 2 has a planning observation
expect(result.current.size).toBe(1);
expect(result.current.has("plan-obs-1")).toBe(true);
});
it("should handle events starting with non-user message", () => {
const events: OpenHandsEvent[] = [
createOtherObservationEvent("obs-1"),
createUserMessageEvent("user-1"),
createPlanningObservationEvent("plan-obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
// Events before first user message should be in first phase
expect(result.current.size).toBe(1);
expect(result.current.has("plan-obs-1")).toBe(true);
});
});
describe("shouldShowPlanPreview", () => {
it("should return true when event ID is in the set", () => {
const planPreviewEventIds = new Set(["event-1", "event-2", "event-3"]);
expect(shouldShowPlanPreview("event-2", planPreviewEventIds)).toBe(true);
});
it("should return false when event ID is not in the set", () => {
const planPreviewEventIds = new Set(["event-1", "event-2"]);
expect(shouldShowPlanPreview("event-3", planPreviewEventIds)).toBe(false);
});
it("should return false when set is empty", () => {
const planPreviewEventIds = new Set<string>();
expect(shouldShowPlanPreview("event-1", planPreviewEventIds)).toBe(false);
});
});
@@ -40,18 +40,6 @@ import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup";
import { useEventStore } from "#/stores/use-event-store";
import { isV1Event } from "#/types/v1/type-guards";
// Mock useUserConversation to return V1 conversation data
vi.mock("#/hooks/query/use-user-conversation", () => ({
useUserConversation: vi.fn(() => ({
data: {
conversation_version: "V1",
status: "RUNNING",
},
isLoading: false,
error: null,
})),
}));
// MSW WebSocket mock setup
const { wsLink, server: mswServer } = conversationWebSocketTestSetup();
@@ -679,16 +667,6 @@ describe("Conversation WebSocket Handler", () => {
// Set up MSW to mock both the HTTP API and WebSocket connection
mswServer.use(
// Mock events search for history preloading
http.get(
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
async () => {
await new Promise((resolve) => setTimeout(resolve, 10));
return HttpResponse.json({
items: mockHistoryEvents,
});
},
),
http.get(
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
() => HttpResponse.json(expectedEventCount),
@@ -725,6 +703,11 @@ describe("Conversation WebSocket Handler", () => {
`http://localhost:3000/api/conversations/${conversationId}`,
);
// Initially should be loading history
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
"true",
);
// Wait for all events to be received
await waitFor(() => {
expect(screen.getByTestId("events-received")).toHaveTextContent("3");
@@ -743,14 +726,6 @@ describe("Conversation WebSocket Handler", () => {
// Set up MSW to mock both the HTTP API and WebSocket connection
mswServer.use(
// Mock empty events search
http.get(
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
() =>
HttpResponse.json({
items: [],
}),
),
http.get(
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
() => HttpResponse.json(0),
@@ -800,16 +775,6 @@ describe("Conversation WebSocket Handler", () => {
// Set up MSW to mock both the HTTP API and WebSocket connection
mswServer.use(
// Mock events search for history preloading (50 events)
http.get(
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
async () => {
await new Promise((resolve) => setTimeout(resolve, 10));
return HttpResponse.json({
items: mockHistoryEvents,
});
},
),
http.get(
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
() => HttpResponse.json(expectedEventCount),
@@ -845,6 +810,11 @@ describe("Conversation WebSocket Handler", () => {
`http://localhost:3000/api/conversations/${conversationId}`,
);
// Initially should be loading history
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
"true",
);
// Wait for all events to be received
await waitFor(() => {
expect(screen.getByTestId("events-received")).toHaveTextContent("50");
@@ -1,114 +0,0 @@
import { describe, it, expect, afterEach, vi } from "vitest";
import React from "react";
import { renderHook, waitFor } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
import EventService from "#/api/event-service/event-service.api";
import { useUserConversation } from "#/hooks/query/use-user-conversation";
import type { Conversation } from "#/api/open-hands.types";
import type { OpenHandsEvent } from "#/types/v1/core";
function makeConversation(version: "V0" | "V1"): Conversation {
return {
conversation_id: "conv-test",
title: "Test Conversation",
selected_repository: null,
selected_branch: null,
git_provider: null,
last_updated_at: new Date().toISOString(),
created_at: new Date().toISOString(),
status: "RUNNING",
runtime_status: null,
url: null,
session_api_key: null,
conversation_version: version,
};
}
function makeEvent(): OpenHandsEvent {
return {
id: "evt-1",
} as OpenHandsEvent;
}
// --------------------
// Mocks
// --------------------
vi.mock("#/api/open-hands-axios", () => ({
openHands: {
get: vi.fn(),
},
}));
vi.mock("#/api/event-service/event-service.api");
vi.mock("#/hooks/query/use-user-conversation");
const queryClient = new QueryClient();
function wrapper({ children }: { children: React.ReactNode }) {
return (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
}
// --------------------
// Tests
// --------------------
describe("useConversationHistory", () => {
afterEach(() => {
vi.clearAllMocks();
});
it("calls V1 REST endpoint for V1 conversations", async () => {
const v1SearchEventsSpy = vi.spyOn(EventService, "searchEventsV1");
vi.mocked(useUserConversation).mockReturnValue({
data: makeConversation("V1"),
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
v1SearchEventsSpy.mockResolvedValue([makeEvent()]);
const { result } = renderHook(() => useConversationHistory("conv-123"), {
wrapper,
});
await waitFor(() => {
expect(result.current.data).toBeDefined();
});
expect(EventService.searchEventsV1).toHaveBeenCalledWith("conv-123");
expect(EventService.searchEventsV0).not.toHaveBeenCalled();
});
it("calls V0 REST endpoint for V0 conversations", async () => {
const v0SearchEventsSpy = vi.spyOn(EventService, "searchEventsV0");
vi.mocked(useUserConversation).mockReturnValue({
data: makeConversation("V0"),
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
v0SearchEventsSpy.mockResolvedValue([makeEvent()]);
const { result } = renderHook(() => useConversationHistory("conv-456"), {
wrapper,
});
await waitFor(() => {
expect(result.current.data).toBeDefined();
});
expect(EventService.searchEventsV0).toHaveBeenCalledWith("conv-456");
expect(EventService.searchEventsV1).not.toHaveBeenCalled();
});
});
+8 -61
View File
@@ -7,19 +7,14 @@ import LoginPage from "#/routes/login";
import OptionService from "#/api/option-service/option-service.api";
import AuthService from "#/api/auth-service/auth-service.api";
const { useEmailVerificationMock, resendEmailVerificationMock } = vi.hoisted(
() => ({
useEmailVerificationMock: vi.fn(() => ({
emailVerified: false,
hasDuplicatedEmail: false,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
userId: null as string | null,
resendEmailVerification: vi.fn(),
})),
resendEmailVerificationMock: vi.fn(),
}),
);
const { useEmailVerificationMock } = vi.hoisted(() => ({
useEmailVerificationMock: vi.fn(() => ({
emailVerified: false,
hasDuplicatedEmail: false,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
})),
}));
vi.mock("#/hooks/use-github-auth-url", () => ({
useGitHubAuthUrl: () => "https://github.com/login/oauth/authorize",
@@ -353,8 +348,6 @@ describe("LoginPage", () => {
hasDuplicatedEmail: false,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
userId: null,
resendEmailVerification: resendEmailVerificationMock,
});
render(<RouterStub initialEntries={["/login"]} />, {
@@ -374,8 +367,6 @@ describe("LoginPage", () => {
hasDuplicatedEmail: true,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
userId: null,
resendEmailVerification: resendEmailVerificationMock,
});
render(<RouterStub initialEntries={["/login"]} />, {
@@ -388,41 +379,6 @@ describe("LoginPage", () => {
).toBeInTheDocument();
});
});
it("should pass userId to EmailVerificationModal when userId is provided", async () => {
const user = userEvent.setup();
const testUserId = "test-user-id-123";
const setEmailVerificationModalOpen = vi.fn();
useEmailVerificationMock.mockReturnValue({
emailVerified: false,
hasDuplicatedEmail: false,
emailVerificationModalOpen: true,
setEmailVerificationModalOpen,
userId: testUserId,
resendEmailVerification: resendEmailVerificationMock,
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
).toBeInTheDocument();
});
const resendButton = screen.getByRole("button", {
name: /SETTINGS\$RESEND_VERIFICATION/i,
});
await user.click(resendButton);
expect(resendEmailVerificationMock).toHaveBeenCalledWith({
userId: testUserId,
isAuthFlow: true,
});
});
});
describe("Loading States", () => {
@@ -459,15 +415,6 @@ describe("LoginPage", () => {
describe("Terms and Privacy", () => {
it("should display Terms and Privacy notice", async () => {
useEmailVerificationMock.mockReturnValue({
emailVerified: false,
hasDuplicatedEmail: false,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
userId: null as string | null,
resendEmailVerification: resendEmailVerificationMock,
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
@@ -48,7 +48,6 @@ function LoginStub() {
searchParams.get("email_verification_required") === "true";
const emailVerified = searchParams.get("email_verified") === "true";
const emailVerificationText = "AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY";
const returnTo = searchParams.get("returnTo");
return (
<div data-testid="login-page">
@@ -59,7 +58,6 @@ function LoginStub() {
{emailVerificationText}
</div>
)}
{returnTo && <div data-testid="return-to-param">{returnTo}</div>}
</div>
</div>
);
@@ -102,27 +100,6 @@ const RouterStubWithLogin = createRoutesStub([
},
]);
const RouterStubWithDeviceVerify = createRoutesStub([
{
Component: MainApp,
path: "/",
children: [
{
Component: () => <div data-testid="outlet-content" />,
path: "/",
},
{
Component: () => <div data-testid="device-verify-page" />,
path: "/oauth/device/verify",
},
],
},
{
Component: LoginStub,
path: "/login",
},
]);
const renderMainApp = (initialEntries: string[] = ["/"]) =>
render(<RouterStub initialEntries={initialEntries} />, {
wrapper: ({ children }) => (
@@ -334,23 +311,5 @@ describe("MainApp", () => {
{ timeout: 2000 },
);
});
it("should preserve query parameters in returnTo when redirecting to login", async () => {
renderWithLoginStub(RouterStubWithDeviceVerify, [
"/oauth/device/verify?user_code=F9XN6BKU",
]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
const returnToElement = screen.getByTestId("return-to-param");
expect(returnToElement).toBeInTheDocument();
expect(returnToElement.textContent).toBe(
"/oauth/device/verify?user_code=F9XN6BKU",
);
},
{ timeout: 2000 },
);
});
});
});
@@ -138,72 +138,4 @@ describe("handleEventForUI", () => {
anotherActionEvent,
]);
});
it("should NOT replace ThinkAction with ThinkObservation", () => {
const mockThinkAction: ActionEvent = {
id: "test-think-action-1",
timestamp: Date.now().toString(),
source: "agent",
thought: [{ type: "text", text: "I am thinking..." }],
thinking_blocks: [],
action: {
kind: "ThinkAction",
thought: "I am thinking...",
},
tool_name: "think",
tool_call_id: "call_think_1",
tool_call: {
id: "call_think_1",
type: "function",
function: {
name: "think",
arguments: "",
},
},
llm_response_id: "response_think",
security_risk: SecurityRisk.UNKNOWN,
};
const mockThinkObservation: ObservationEvent = {
id: "test-think-observation-1",
timestamp: Date.now().toString(),
source: "environment",
tool_name: "think",
tool_call_id: "call_think_1",
observation: {
kind: "ThinkObservation",
content: [{ type: "text", text: "Your thought has been logged." }],
},
action_id: "test-think-action-1",
};
const initialUiEvents = [mockMessageEvent, mockThinkAction];
const result = handleEventForUI(mockThinkObservation, initialUiEvents);
// ThinkObservation should NOT be added - ThinkAction should remain
expect(result).toEqual([mockMessageEvent, mockThinkAction]);
expect(result).not.toBe(initialUiEvents);
});
it("should NOT add ThinkObservation even when ThinkAction is not found", () => {
const mockThinkObservation: ObservationEvent = {
id: "test-think-observation-1",
timestamp: Date.now().toString(),
source: "environment",
tool_name: "think",
tool_call_id: "call_think_1",
observation: {
kind: "ThinkObservation",
content: [{ type: "text", text: "Your thought has been logged." }],
},
action_id: "test-think-action-not-found",
};
const initialUiEvents = [mockMessageEvent];
const result = handleEventForUI(mockThinkObservation, initialUiEvents);
// ThinkObservation should never be added to uiEvents
expect(result).toEqual([mockMessageEvent]);
expect(result).not.toBe(initialUiEvents);
});
});
@@ -103,7 +103,7 @@ export interface V1AppConversation {
export interface Skill {
name: string;
type: "repo" | "knowledge" | "agentskills";
type: "repo" | "knowledge";
content: string;
triggers: string[];
}
@@ -5,8 +5,6 @@ import type {
ConfirmationResponseRequest,
ConfirmationResponseResponse,
} from "./event-service.types";
import { openHands } from "../open-hands-axios";
import { OpenHandsEvent } from "#/types/v1/core";
class EventService {
/**
@@ -63,27 +61,5 @@ class EventService {
);
return data;
}
// V1 conversations — App Server REST endpoint
static async searchEventsV1(conversationId: string, limit = 100) {
const { data } = await openHands.get<{
items: OpenHandsEvent[];
}>(`/api/v1/conversation/${conversationId}/events/search`, {
params: { limit },
});
return data.items;
}
// V0 conversations — Legacy REST endpoint
static async searchEventsV0(conversationId: string, limit = 100) {
const { data } = await openHands.get<{
events: OpenHandsEvent[];
}>(`/api/conversations/${conversationId}/events`, {
params: { limit },
});
return data.events;
}
}
export default EventService;
+1 -1
View File
@@ -110,7 +110,7 @@ export interface InputMetadata {
export interface Microagent {
name: string;
type: "repo" | "knowledge" | "agentskills";
type: "repo" | "knowledge";
content: string;
triggers: string[];
}
@@ -1,24 +1,22 @@
import { useMemo } from "react";
import { useTranslation } from "react-i18next";
import { ArrowUpRight } from "lucide-react";
import LessonPlanIcon from "#/icons/lesson-plan.svg?react";
import { USE_PLANNING_AGENT } from "#/utils/feature-flags";
import { Typography } from "#/ui/typography";
import { I18nKey } from "#/i18n/declaration";
import { MarkdownRenderer } from "#/components/features/markdown/markdown-renderer";
const MAX_CONTENT_LENGTH = 300;
interface PlanPreviewProps {
/** Raw plan content from PLAN.md file */
planContent?: string | null;
title?: string;
description?: string;
onViewClick?: () => void;
onBuildClick?: () => void;
}
// TODO: Remove the hardcoded values and use the plan content from the conversation store
/* eslint-disable i18next/no-literal-string */
export function PlanPreview({
planContent,
title = "Improve Developer Onboarding and Examples",
description = "Based on the analysis of Browser-Use's current documentation and examples, this plan addresses gaps in developer onboarding by creating a progressive learning path, troubleshooting resources, and practical examples that address real-world scenarios (like the LM Studio/local LLM integration issues encountered...",
onViewClick,
onBuildClick,
}: PlanPreviewProps) {
@@ -26,13 +24,6 @@ export function PlanPreview({
const shouldUsePlanningAgent = USE_PLANNING_AGENT();
// Truncate plan content for preview
const truncatedContent = useMemo(() => {
if (!planContent) return "";
if (planContent.length <= MAX_CONTENT_LENGTH) return planContent;
return `${planContent.slice(0, MAX_CONTENT_LENGTH)}...`;
}, [planContent]);
if (!shouldUsePlanningAgent) {
return null;
}
@@ -50,7 +41,6 @@ export function PlanPreview({
type="button"
onClick={onViewClick}
className="flex items-center gap-1 hover:opacity-80 transition-opacity"
data-testid="plan-preview-view-button"
>
<Typography.Text className="font-medium text-[11px] text-white tracking-[0.11px] leading-4">
{t(I18nKey.COMMON$VIEW)}
@@ -60,27 +50,16 @@ export function PlanPreview({
</div>
{/* Content */}
<div
data-testid="plan-preview-content"
className="flex flex-col gap-[10px] p-4 text-[15px] text-white leading-[29px]"
>
{truncatedContent && (
<>
<MarkdownRenderer includeStandard includeHeadings>
{truncatedContent}
</MarkdownRenderer>
{planContent && planContent.length > MAX_CONTENT_LENGTH && (
<button
type="button"
onClick={onViewClick}
className="text-[#4a67bd] cursor-pointer hover:underline text-left"
data-testid="plan-preview-read-more-button"
>
{t(I18nKey.COMMON$READ_MORE)}
</button>
)}
</>
)}
<div className="flex flex-col gap-[10px] p-4">
<h3 className="font-bold text-[19px] text-white leading-[29px]">
{title}
</h3>
<p className="text-[15px] text-white leading-[29px]">
{description}
<Typography.Text className="text-[#4a67bd] cursor-pointer hover:underline ml-1">
{t(I18nKey.COMMON$READ_MORE)}
</Typography.Text>
</p>
</div>
{/* Footer */}
@@ -89,7 +68,6 @@ export function PlanPreview({
type="button"
onClick={onBuildClick}
className="bg-white flex items-center justify-center h-[26px] px-2 rounded-[4px] w-[93px] hover:opacity-90 transition-opacity cursor-pointer"
data-testid="plan-preview-build-button"
>
<Typography.Text className="font-medium text-[14px] text-black leading-5">
{t(I18nKey.COMMON$BUILD)}{" "}
@@ -11,15 +11,6 @@ interface SkillItemProps {
}
export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) {
let skillTypeLabel: string;
if (skill.type === "repo") {
skillTypeLabel = "Repository";
} else if (skill.type === "knowledge") {
skillTypeLabel = "Knowledge";
} else {
skillTypeLabel = "AgentSkills";
}
return (
<div className="rounded-md overflow-hidden">
<button
@@ -34,7 +25,7 @@ export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) {
</div>
<div className="flex items-center">
<Typography.Text className="px-2 py-1 text-xs rounded-full bg-gray-800 mr-2">
{skillTypeLabel}
{skill.type === "repo" ? "Repository" : "Knowledge"}
</Typography.Text>
<Typography.Text className="text-gray-300">
{isExpanded ? (
@@ -87,6 +87,16 @@ export function GitBranchDropdown({
[onBranchSelect],
);
// Handle input value change
const handleInputValueChange = useCallback(
({ inputValue: newInputValue }: { inputValue?: string }) => {
if (newInputValue !== undefined) {
setInputValue(newInputValue);
}
},
[],
);
// Handle menu scroll for infinite loading
const handleMenuScroll = useCallback(
(event: React.UIEvent<HTMLUListElement>) => {
@@ -118,14 +128,8 @@ export function GitBranchDropdown({
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
handleBranchSelect(newSelectedItem || null);
},
onInputValueChange: handleInputValueChange,
inputValue,
// Override Downshift's default input-click behavior to avoid closing/reopening
// the menu, which would reset scroll position and break search continuity.
stateReducer: (state, actionAndChanges) =>
actionAndChanges.type === useCombobox.stateChangeTypes.InputClick &&
state.isOpen
? { ...actionAndChanges.changes, isOpen: true }
: actionAndChanges.changes,
});
// Reset branch selection when repository changes
@@ -172,12 +176,12 @@ export function GitBranchDropdown({
// Initialize input value when selectedBranch changes (but not when user is typing)
useEffect(() => {
if (selectedBranch && !isOpen) {
if (selectedBranch && !isOpen && inputValue !== selectedBranch.name) {
setInputValue(selectedBranch.name);
} else if (!selectedBranch && !isOpen) {
} else if (!selectedBranch && !isOpen && inputValue) {
setInputValue("");
}
}, [selectedBranch, isOpen]);
}, [selectedBranch, isOpen, inputValue]);
const isLoadingState = isLoading || isSearchLoading || isFetchingNextPage;
@@ -203,10 +207,6 @@ export function GitBranchDropdown({
"disabled:bg-[#363636] disabled:cursor-not-allowed disabled:opacity-60",
"pl-7 pr-16 text-sm font-normal leading-5", // Space for clear and toggle buttons
),
// Direct onChange for cursor position preservation
onChange: (e: React.ChangeEvent<HTMLInputElement>) => {
setInputValue(e.target.value);
},
})}
data-testid="git-branch-dropdown-input"
/>
@@ -184,6 +184,14 @@ export function GitRepoDropdown({
setInputValue("");
}, [handleSelectionChange]);
// Handle input value change
const handleInputValueChange = useCallback(
({ inputValue: newInputValue }: { inputValue?: string }) => {
setInputValue(newInputValue || "");
},
[],
);
// Handle scroll to bottom for pagination
const handleMenuScroll = useCallback(
(event: React.UIEvent<HTMLUListElement>) => {
@@ -212,14 +220,8 @@ export function GitRepoDropdown({
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
handleSelectionChange(newSelectedItem);
},
onInputValueChange: handleInputValueChange,
inputValue,
// Override Downshift's default input-click behavior to avoid closing/reopening
// the menu, which would reset scroll position and break search continuity.
stateReducer: (state, actionAndChanges) =>
actionAndChanges.type === useCombobox.stateChangeTypes.InputClick &&
state.isOpen
? { ...actionAndChanges.changes, isOpen: true }
: actionAndChanges.changes,
});
// Sync localSelectedItem with external value prop
@@ -235,8 +237,6 @@ export function GitRepoDropdown({
useEffect(() => {
if (selectedRepository && !isOpen) {
setInputValue(selectedRepository.full_name);
} else if (!selectedRepository && !isOpen) {
setInputValue("");
}
}, [selectedRepository, isOpen]);
@@ -335,10 +335,6 @@ export function GitRepoDropdown({
"disabled:bg-[#363636] disabled:cursor-not-allowed disabled:opacity-60",
"pl-7 pr-16 text-sm font-normal leading-5", // Space for clear and toggle buttons
),
// Direct onChange for cursor position preservation
onChange: (e: React.ChangeEvent<HTMLInputElement>) => {
setInputValue(e.target.value);
},
})}
data-testid="git-repo-dropdown"
/>
@@ -27,11 +27,6 @@ export const shouldRenderEvent = (event: OpenHandsEvent) => {
return false;
}
// Hide PlanningFileEditorAction - handled separately with PlanPreview component
if (actionType === "PlanningFileEditorAction") {
return false;
}
return true;
}
@@ -6,11 +6,9 @@ import {
isObservationEvent,
isAgentErrorEvent,
isUserMessageEvent,
isPlanningFileEditorObservationEvent,
} from "#/types/v1/type-guards";
import { MicroagentStatus } from "#/types/microagent-status";
import { useConfig } from "#/hooks/query/use-config";
import { useConversationStore } from "#/stores/conversation-store";
// TODO: Implement V1 feedback functionality when API supports V1 event IDs
// import { useFeedbackExists } from "#/hooks/query/use-feedback-exists";
import {
@@ -21,8 +19,6 @@ import {
ThoughtEventMessage,
} from "./event-message-components";
import { createSkillReadyEvent } from "./event-content-helpers/create-skill-ready-event";
import { PlanPreview } from "../../features/chat/plan-preview";
import { shouldShowPlanPreview } from "./hooks/use-plan-preview-events";
interface EventMessageProps {
event: OpenHandsEvent & { isFromPlanningAgent?: boolean };
@@ -37,8 +33,6 @@ interface EventMessageProps {
tooltip?: string;
}>;
isInLast10Actions: boolean;
/** Set of event IDs that should render PlanPreview (one per user message phase) */
planPreviewEventIds?: Set<string>;
}
/**
@@ -149,10 +143,8 @@ export function EventMessage({
microagentPRUrl,
actions,
isInLast10Actions,
planPreviewEventIds,
}: EventMessageProps) {
const { data: config } = useConfig();
const { planContent } = useConversationStore();
// V1 events use string IDs, but useFeedbackExists expects number
// For now, we'll skip feedback functionality for V1 events
@@ -206,21 +198,6 @@ export function EventMessage({
// Observation events - find the corresponding action and render thought + observation
if (isObservationEvent(event)) {
// Handle PlanningFileEditorObservation specially
if (isPlanningFileEditorObservationEvent(event)) {
// Only show PlanPreview if this event is marked as the one to display
// (last PlanningFileEditorObservation in its phase)
if (
planPreviewEventIds &&
shouldShowPlanPreview(event.id, planPreviewEventIds)
) {
return <PlanPreview planContent={planContent} />;
}
// Not the designated preview event for this phase - render nothing
// This prevents duplicate previews within the same phase
return null;
}
// Find the action that this observation is responding to
const correspondingAction = messages.find(
(msg) => isActionEvent(msg) && msg.id === event.action_id,
@@ -1,114 +0,0 @@
import { useMemo } from "react";
import { OpenHandsEvent } from "#/types/v1/core";
import {
isUserMessageEvent,
isPlanningFileEditorObservationEvent,
} from "#/types/v1/type-guards";
/**
* Groups events into phases based on user messages.
* A phase starts with a user message and includes all subsequent events
* until the next user message.
*
* @param events - The full list of events
* @returns Array of phases, where each phase is an array of events
*/
function groupEventsByPhase(events: OpenHandsEvent[]): OpenHandsEvent[][] {
const phases: OpenHandsEvent[][] = [];
let currentPhase: OpenHandsEvent[] = [];
for (const event of events) {
if (isUserMessageEvent(event)) {
// Start a new phase with the user message
if (currentPhase.length > 0) {
phases.push(currentPhase);
}
currentPhase = [event];
} else {
// Add event to current phase
currentPhase.push(event);
}
}
// Don't forget the last phase
if (currentPhase.length > 0) {
phases.push(currentPhase);
}
return phases;
}
/**
* Finds the last PlanningFileEditorObservation in a phase.
*
* @param phase - Array of events in a phase
* @returns The event ID of the last PlanningFileEditorObservation, or null
*/
function findLastPlanningObservationInPhase(
phase: OpenHandsEvent[],
): string | null {
// Iterate backwards to find the last one
for (let i = phase.length - 1; i >= 0; i -= 1) {
const event = phase[i];
if (isPlanningFileEditorObservationEvent(event)) {
return event.id;
}
}
return null;
}
export interface PlanPreviewEventInfo {
eventId: string;
/** Index of this plan preview in the conversation (1st, 2nd, etc.) */
phaseIndex: number;
}
/**
* Hook to determine which PlanningFileEditorObservation events should render PlanPreview.
*
* This hook implements phase-based grouping where:
* - A phase starts with a user message and ends at the next user message
* - Only the LAST PlanningFileEditorObservation in each phase shows PlanPreview
* - This ensures only one preview per user request, even with multiple observations
*
* Scenario handling:
* - Scenario 1 (Create plan): Multiple observations in one phase 1 preview
* - Scenario 2 (Create then update): Two user messages two phases 2 previews
* - Scenario 3 (Create + update while processing): Two user messages 2 previews
*
* @param allEvents - Full list of v1 events (for phase detection)
* @returns Set of event IDs that should render PlanPreview
*/
export function usePlanPreviewEvents(allEvents: OpenHandsEvent[]): Set<string> {
return useMemo(() => {
const planPreviewEventIds = new Set<string>();
// Group events by phases (user message boundaries)
const phases = groupEventsByPhase(allEvents);
// For each phase, find the last PlanningFileEditorObservation
phases.forEach((phase) => {
const lastPlanningObservationId =
findLastPlanningObservationInPhase(phase);
if (lastPlanningObservationId) {
planPreviewEventIds.add(lastPlanningObservationId);
}
});
return planPreviewEventIds;
}, [allEvents]);
}
/**
* Check if a specific event should render PlanPreview.
*
* @param eventId - The event ID to check
* @param planPreviewEventIds - Set of event IDs that should render PlanPreview
* @returns true if this event should render PlanPreview
*/
export function shouldShowPlanPreview(
eventId: string,
planPreviewEventIds: Set<string>,
): boolean {
return planPreviewEventIds.has(eventId);
}
@@ -3,7 +3,6 @@ import { OpenHandsEvent } from "#/types/v1/core";
import { EventMessage } from "./event-message";
import { ChatMessage } from "../../features/chat/chat-message";
import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store";
import { usePlanPreviewEvents } from "./hooks/use-plan-preview-events";
// TODO: Implement microagent functionality for V1 when APIs support V1 event IDs
// import { AgentState } from "#/types/agent-state";
// import MemoryIcon from "#/icons/memory_icon.svg?react";
@@ -19,10 +18,6 @@ export const Messages: React.FC<MessagesProps> = React.memo(
const optimisticUserMessage = getOptimisticUserMessage();
// Get the set of event IDs that should render PlanPreview
// This ensures only one preview per user message "phase"
const planPreviewEventIds = usePlanPreviewEvents(allEvents);
// TODO: Implement microagent functionality for V1 if needed
// For now, we'll skip microagent features
@@ -35,7 +30,6 @@ export const Messages: React.FC<MessagesProps> = React.memo(
messages={allEvents}
isLastMessage={messages.length - 1 === index}
isInLast10Actions={messages.length - 1 - index < 10}
planPreviewEventIds={planPreviewEventIds}
// Microagent props - not implemented yet for V1
// microagentStatus={undefined}
// microagentConversationId={undefined}
@@ -46,7 +46,6 @@ import { useTracking } from "#/hooks/use-tracking";
import { useReadConversationFile } from "#/hooks/mutation/use-read-conversation-file";
import useMetricsStore from "#/stores/metrics-store";
import { I18nKey } from "#/i18n/declaration";
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
// eslint-disable-next-line @typescript-eslint/naming-convention
export type V1_WebSocketConnectionState =
@@ -307,21 +306,6 @@ export function ConversationWebSocketProvider({
latestPlanningFileEventRef.current = null;
}, [conversationId]);
const { data: preloadedEvents } = useConversationHistory(conversationId);
useEffect(() => {
if (!preloadedEvents || preloadedEvents.length === 0) {
setIsLoadingHistoryMain(false);
return;
}
for (const event of preloadedEvents) {
addEvent(event);
}
setIsLoadingHistoryMain(false);
}, [preloadedEvents, addEvent]);
// Separate message handlers for each connection
const handleMainMessage = useCallback(
(messageEvent: MessageEvent) => {
@@ -1,22 +0,0 @@
import { useQuery } from "@tanstack/react-query";
import EventService from "#/api/event-service/event-service.api";
import { useUserConversation } from "#/hooks/query/use-user-conversation";
export const useConversationHistory = (conversationId?: string) => {
const { data: conversation } = useUserConversation(conversationId ?? null);
return useQuery({
queryKey: ["conversation-history", conversationId, conversation],
enabled: !!conversationId && !!conversation,
queryFn: async () => {
if (!conversationId || !conversation) return [];
if (conversation.conversation_version === "V1") {
return EventService.searchEventsV1(conversationId);
}
return EventService.searchEventsV0(conversationId);
},
staleTime: 30_000,
});
};
-2
View File
@@ -20,7 +20,6 @@ export default function LoginPage() {
recaptchaBlocked,
emailVerificationModalOpen,
setEmailVerificationModalOpen,
userId,
} = useEmailVerification();
const gitHubAuthUrl = useGitHubAuthUrl({
@@ -78,7 +77,6 @@ export default function LoginPage() {
onClose={() => {
setEmailVerificationModalOpen(false);
}}
userId={userId}
/>
)}
</>
+4 -11
View File
@@ -5,7 +5,6 @@ import {
Outlet,
useNavigate,
useLocation,
useSearchParams,
} from "react-router";
import { useTranslation } from "react-i18next";
import { I18nKey } from "#/i18n/declaration";
@@ -68,7 +67,6 @@ export default function MainApp() {
const appTitle = useAppTitle();
const navigate = useNavigate();
const { pathname } = useLocation();
const [searchParams] = useSearchParams();
const isOnTosPage = useIsOnTosPage();
const { data: settings } = useSettings();
const { migrateUserConsent } = useMigrateUserConsent();
@@ -184,18 +182,13 @@ export default function MainApp() {
React.useEffect(() => {
if (shouldRedirectToLogin) {
// Include search params in returnTo to preserve query string (e.g., user_code for device OAuth)
const searchString = searchParams.toString();
let fullPath = "";
if (pathname !== "/") {
fullPath = searchString ? `${pathname}?${searchString}` : pathname;
}
const loginUrl = fullPath
? `/login?returnTo=${encodeURIComponent(fullPath)}`
const returnTo = pathname !== "/" ? pathname : "";
const loginUrl = returnTo
? `/login?returnTo=${encodeURIComponent(returnTo)}`
: "/login";
navigate(loginUrl, { replace: true });
}
}, [shouldRedirectToLogin, pathname, searchParams, navigate]);
}, [shouldRedirectToLogin, pathname, navigate]);
if (shouldRedirectToLogin) {
return (
-32
View File
@@ -213,37 +213,6 @@ export interface BrowserCloseTabAction extends ActionBase<"BrowserCloseTabAction
tab_id: string;
}
export interface PlanningFileEditorAction extends ActionBase<"PlanningFileEditorAction"> {
/**
* The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.
*/
command: "view" | "create" | "str_replace" | "insert" | "undo_edit";
/**
* Absolute path to file (typically /workspace/project/PLAN.md).
*/
path: string;
/**
* Required parameter of `create` command, with the content of the file to be created.
*/
file_text: string | null;
/**
* Required parameter of `str_replace` command containing the string in `path` to replace.
*/
old_str: string | null;
/**
* Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.
*/
new_str: string | null;
/**
* Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`. Must be >= 1.
*/
insert_line: number | null;
/**
* Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown.
*/
view_range: [number, number] | null;
}
export type Action =
| MCPToolAction
| FinishAction
@@ -253,7 +222,6 @@ export type Action =
| FileEditorAction
| StrReplaceEditorAction
| TaskTrackerAction
| PlanningFileEditorAction
| BrowserNavigateAction
| BrowserClickAction
| BrowserTypeAction
+1 -7
View File
@@ -4,7 +4,6 @@ import { isObservationEvent } from "#/types/v1/type-guards";
/**
* Handles adding an event to the UI events array
* Replaces actions with observations when they arrive (so UI shows observation instead of action)
* Exception: ThinkAction is NOT replaced because the thought content is in the action, not in the observation
*/
export const handleEventForUI = (
event: OpenHandsEvent,
@@ -13,17 +12,12 @@ export const handleEventForUI = (
const newUiEvents = [...uiEvents];
if (isObservationEvent(event)) {
// Don't add ThinkObservation at all - we keep the ThinkAction instead
// The thought content is in the action, not the observation
if (event.observation.kind === "ThinkObservation") {
return newUiEvents;
}
// Find and replace the corresponding action from uiEvents
const actionIndex = newUiEvents.findIndex(
(uiEvent) => uiEvent.id === event.action_id,
);
if (actionIndex !== -1) {
// Replace the action with the observation
newUiEvents[actionIndex] = event;
} else {
// Action not found in uiEvents, just add the observation
-104
View File
@@ -7,13 +7,6 @@ import { I18nextProvider, initReactI18next } from "react-i18next";
import i18n from "i18next";
import { vi } from "vitest";
import { AxiosError } from "axios";
import {
ActionEvent,
MessageEvent,
ObservationEvent,
PlanningFileEditorObservation,
} from "#/types/v1/core";
import { SecurityRisk } from "#/types/v1/core";
export const useParamsMock = vi.fn(() => ({
conversationId: "test-conversation-id",
@@ -105,100 +98,3 @@ export const createAxiosError = (
config: {},
},
);
// Helper to create a PlanningFileEditorAction event
export const createPlanningFileEditorActionEvent = (
id: string,
): ActionEvent => ({
id,
timestamp: new Date().toISOString(),
source: "agent",
thought: [{ type: "text", text: "Planning action" }],
thinking_blocks: [],
action: {
kind: "PlanningFileEditorAction",
command: "create",
path: "/workspace/PLAN.md",
file_text: "Plan content",
old_str: null,
new_str: null,
insert_line: null,
view_range: null,
},
tool_name: "planning_file_editor",
tool_call_id: "call-1",
tool_call: {
id: "call-1",
type: "function",
function: {
name: "planning_file_editor",
arguments: '{"command": "create"}',
},
},
llm_response_id: "response-1",
security_risk: SecurityRisk.UNKNOWN,
});
// Helper to create a non-planning action event
export const createOtherActionEvent = (id: string): ActionEvent => ({
id,
timestamp: new Date().toISOString(),
source: "agent",
thought: [{ type: "text", text: "Other action" }],
thinking_blocks: [],
action: {
kind: "ExecuteBashAction",
command: "echo test",
is_input: false,
timeout: null,
reset: false,
},
tool_name: "execute_bash",
tool_call_id: "call-1",
tool_call: {
id: "call-1",
type: "function",
function: {
name: "execute_bash",
arguments: '{"command": "echo test"}',
},
},
llm_response_id: "response-1",
security_risk: SecurityRisk.UNKNOWN,
});
// Helper to create a PlanningFileEditorObservation event
export const createPlanningObservationEvent = (
id: string,
actionId: string = "action-1",
): ObservationEvent<PlanningFileEditorObservation> => ({
id,
timestamp: new Date().toISOString(),
source: "environment",
tool_name: "planning_file_editor",
tool_call_id: "call-1",
action_id: actionId,
observation: {
kind: "PlanningFileEditorObservation",
content: [{ type: "text", text: "Plan content" }],
is_error: false,
command: "create",
path: "/workspace/PLAN.md",
prev_exist: false,
old_content: null,
new_content: "Plan content",
},
});
// Helper to create a user message event
export const createUserMessageEvent = (id: string): MessageEvent => ({
id,
timestamp: new Date().toISOString(),
source: "user",
llm_message: {
role: "user",
content: [{ type: "text", text: "User message" }],
},
activated_microagents: [],
extended_content: [],
});
@@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Any, Literal
from typing import Literal
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
@@ -14,7 +14,6 @@ from openhands.app_server.sandbox.sandbox_models import SandboxStatus
from openhands.integrations.service_types import ProviderType
from openhands.sdk.conversation.state import ConversationExecutionStatus
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.plugin import PluginSource
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@@ -25,45 +24,6 @@ class AgentType(Enum):
PLAN = 'plan'
class PluginSpec(PluginSource):
"""Specification for loading a plugin into a conversation.
Extends SDK's PluginSource with user-provided plugin configuration parameters.
Inherits source, ref, and repo_path fields along with their validation.
"""
parameters: dict[str, Any] | None = Field(
default=None,
description='User-provided values for plugin input parameters',
)
@property
def display_name(self) -> str:
"""Extract a friendly display name from the plugin source.
Examples:
- 'github:owner/repo' -> 'repo'
- 'https://github.com/owner/repo.git' -> 'repo.git'
- '/local/path' -> 'path'
"""
return self.source.split('/')[-1] if '/' in self.source else self.source
def format_params_as_text(self, indent: str = '') -> str | None:
"""Format parameters as a readable text block for display.
Args:
indent: Optional prefix to add before each parameter line.
Returns:
Formatted parameters string, or None if no parameters.
"""
if not self.parameters:
return None
return '\n'.join(
f'{indent}- {key}: {value}' for key, value in self.parameters.items()
)
class AppConversationInfo(BaseModel):
"""Conversation info which does not contain status."""
@@ -158,15 +118,6 @@ class AppConversationStartRequest(OpenHandsModel):
public: bool | None = None
# Plugin parameters - for loading remote plugins into the conversation
plugins: list[PluginSpec] | None = Field(
default=None,
description=(
'List of plugins to load for this conversation. Plugins are loaded '
'and their skills/MCP config are merged into the agent.'
),
)
class AppConversationUpdateRequest(BaseModel):
public: bool
@@ -196,8 +147,7 @@ class AppConversationStartTask(OpenHandsModel):
Because starting an app conversation can be slow (And can involve starting a sandbox),
we kick off a background task for it. Once the conversation is started, the app_conversation_id
is populated.
"""
is populated."""
id: OpenHandsUUID = Field(default_factory=uuid4)
created_by_user_id: str | None
@@ -226,6 +176,6 @@ class SkillResponse(BaseModel):
"""Response model for skills endpoint."""
name: str
type: Literal['repo', 'knowledge', 'agentskills']
type: Literal['repo', 'knowledge']
content: str
triggers: list[str] = []
@@ -503,6 +503,13 @@ async def get_conversation_skills(
agent_server_url = replace_localhost_hostname_for_docker(agent_server_url)
# Create remote workspace
remote_workspace = AsyncRemoteWorkspace(
host=agent_server_url,
api_key=sandbox.session_api_key,
working_dir=sandbox_spec.working_dir,
)
# Load skills from all sources
logger.info(f'Loading skills for conversation {conversation_id}')
@@ -511,9 +518,9 @@ async def get_conversation_skills(
if isinstance(app_conversation_service, AppConversationServiceBase):
all_skills = await app_conversation_service.load_and_merge_all_skills(
sandbox,
remote_workspace,
conversation.selected_repository,
sandbox_spec.working_dir,
agent_server_url,
)
logger.info(
@@ -524,11 +531,9 @@ async def get_conversation_skills(
# Transform skills to response format
skills_response = []
for skill in all_skills:
# Determine type based on AgentSkills format and trigger
skill_type: Literal['repo', 'knowledge', 'agentskills']
if skill.is_agentskills_format:
skill_type = 'agentskills'
elif skill.trigger is None:
# Determine type based on trigger
skill_type: Literal['repo', 'knowledge']
if skill.trigger is None:
skill_type = 'repo'
else:
skill_type = 'knowledge'
@@ -621,6 +626,7 @@ async def _stream_app_conversation_start(
user_context: UserContext,
) -> AsyncGenerator[str, None]:
"""Stream a json list, item by item."""
# Because the original dependencies are closed after the method returns, we need
# a new dependency context which will continue intil the stream finishes.
state = InjectorState()
@@ -95,7 +95,6 @@ class AppConversationService(ABC):
task: AppConversationStartTask,
sandbox: SandboxInfo,
workspace: AsyncRemoteWorkspace,
agent_server_url: str,
) -> AsyncGenerator[AppConversationStartTask, None]:
"""Run the setup scripts for the project and yield status updates"""
yield task
@@ -105,8 +104,7 @@ class AppConversationService(ABC):
self, conversation_id: UUID, request: AppConversationUpdateRequest
) -> AppConversation | None:
"""Update an app conversation and return it. Return None if the conversation
did not exist.
"""
did not exist."""
@abstractmethod
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
@@ -21,16 +21,18 @@ from openhands.app_server.app_conversation.app_conversation_service import (
AppConversationService,
)
from openhands.app_server.app_conversation.skill_loader import (
build_org_config,
build_sandbox_config,
load_skills_from_agent_server,
load_global_skills,
load_org_skills,
load_repo_skills,
load_sandbox_skills,
merge_skills,
)
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.user.user_context import UserContext
from openhands.sdk import Agent
from openhands.sdk.context.agent_context import AgentContext
from openhands.sdk.context.condenser import LLMSummarizingCondenser
from openhands.sdk.context.skills import Skill
from openhands.sdk.context.skills import load_user_skills
from openhands.sdk.llm import LLM
from openhands.sdk.security.analyzer import SecurityAnalyzerBase
from openhands.sdk.security.confirmation_policy import (
@@ -51,8 +53,7 @@ PRE_COMMIT_LOCAL = '.git/hooks/pre-commit.local'
class AppConversationServiceBase(AppConversationService, ABC):
"""App Conversation service which adds git specific functionality.
Sets up repositories and installs hooks
"""
Sets up repositories and installs hooks"""
init_git_in_empty_workspace: bool
user_context: UserContext
@@ -60,74 +61,67 @@ class AppConversationServiceBase(AppConversationService, ABC):
async def load_and_merge_all_skills(
self,
sandbox: SandboxInfo,
remote_workspace: AsyncRemoteWorkspace,
selected_repository: str | None,
working_dir: str,
agent_server_url: str,
) -> list[Skill]:
"""Load skills from all sources via the agent-server.
) -> list:
"""Load skills from all sources and merge them.
This method calls the agent-server's /api/skills endpoint to load and
merge skills from all sources. The agent-server handles:
- Public skills (from OpenHands/skills GitHub repo)
- User skills (from ~/.openhands/skills/)
- Organization skills (from {org}/.openhands repo)
- Project/repo skills (from workspace .openhands/skills/)
- Sandbox skills (from exposed URLs)
This method handles all errors gracefully and will return an empty list
if skill loading fails completely.
Args:
sandbox: SandboxInfo containing exposed URLs and agent-server URL
remote_workspace: AsyncRemoteWorkspace for loading repo skills
selected_repository: Repository name or None
working_dir: Working directory path
agent_server_url: Agent-server URL (required)
Returns:
List of merged Skill objects from all sources, or empty list on failure
"""
try:
_logger.debug('Loading skills for V1 conversation via agent-server')
_logger.debug('Loading skills for V1 conversation')
if not agent_server_url:
_logger.warning('No agent-server URL available, cannot load skills')
return []
# Load skills from all sources
sandbox_skills = load_sandbox_skills(sandbox)
global_skills = load_global_skills()
# Load user skills from ~/.openhands/skills/ directory
# Uses the SDK's load_user_skills() function which handles loading from
# ~/.openhands/skills/ and ~/.openhands/microagents/ (for backward compatibility)
try:
user_skills = load_user_skills()
_logger.info(
f'Loaded {len(user_skills)} user skills: {[s.name for s in user_skills]}'
)
except Exception as e:
_logger.warning(f'Failed to load user skills: {str(e)}')
user_skills = []
# Build org config (authentication handled by app-server)
org_config = await build_org_config(selected_repository, self.user_context)
# Load organization-level skills
org_skills = await load_org_skills(
remote_workspace, selected_repository, working_dir, self.user_context
)
# Build sandbox config (exposed URLs)
sandbox_config = build_sandbox_config(sandbox)
repo_skills = await load_repo_skills(
remote_workspace, selected_repository, working_dir
)
# Determine project directory for project skills
project_dir = working_dir
if selected_repository:
repo_name = selected_repository.split('/')[-1]
project_dir = f'{working_dir}/{repo_name}'
# Single API call to agent-server for ALL skills
all_skills = await load_skills_from_agent_server(
agent_server_url=agent_server_url,
session_api_key=sandbox.session_api_key,
project_dir=project_dir,
org_config=org_config,
sandbox_config=sandbox_config,
load_public=True,
load_user=True,
load_project=True,
load_org=True,
# Merge all skills (later lists override earlier ones)
# Precedence: sandbox < global < user < org < repo
all_skills = merge_skills(
[sandbox_skills, global_skills, user_skills, org_skills, repo_skills]
)
_logger.info(
f'Loaded {len(all_skills)} total skills from agent-server: '
f'{[s.name for s in all_skills]}'
f'Loaded {len(all_skills)} total skills: {[s.name for s in all_skills]}'
)
return all_skills
except Exception as e:
_logger.warning(f'Failed to load skills: {e}', exc_info=True)
# Return empty list on failure - skills will be loaded again later if needed
return []
def _create_agent_with_skills(self, agent, skills: list[Skill]):
def _create_agent_with_skills(self, agent, skills: list):
"""Create or update agent with skills in its context.
Args:
@@ -138,9 +132,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
Updated agent with skills in context
"""
if agent.agent_context:
# Merge with existing context (new skills override existing ones)
# Merge with existing context
existing_skills = agent.agent_context.skills
all_skills = self._merge_skills([existing_skills, skills])
all_skills = merge_skills([skills, existing_skills])
agent = agent.model_copy(
update={
'agent_context': agent.agent_context.model_copy(
@@ -155,25 +149,6 @@ class AppConversationServiceBase(AppConversationService, ABC):
return agent
def _merge_skills(self, skill_lists: list[list[Skill]]) -> list[Skill]:
"""Merge multiple skill lists, avoiding duplicates by name.
Later lists take precedence over earlier lists for duplicate names.
Args:
skill_lists: List of skill lists to merge
Returns:
Deduplicated list of skills with later lists overriding earlier ones
"""
skills_by_name: dict[str, Skill] = {}
for skill_list in skill_lists:
for skill in skill_list:
skills_by_name[skill.name] = skill
return list(skills_by_name.values())
async def _load_skills_and_update_agent(
self,
sandbox: SandboxInfo,
@@ -194,10 +169,8 @@ class AppConversationServiceBase(AppConversationService, ABC):
Updated agent with skills loaded into context
"""
# Load and merge all skills
# Extract agent_server_url from remote_workspace host
agent_server_url = remote_workspace.host
all_skills = await self.load_and_merge_all_skills(
sandbox, selected_repository, working_dir, agent_server_url
sandbox, remote_workspace, selected_repository, working_dir
)
# Update agent with skills
@@ -210,7 +183,6 @@ class AppConversationServiceBase(AppConversationService, ABC):
task: AppConversationStartTask,
sandbox: SandboxInfo,
workspace: AsyncRemoteWorkspace,
agent_server_url: str,
) -> AsyncGenerator[AppConversationStartTask, None]:
task.status = AppConversationStartTaskStatus.PREPARING_REPOSITORY
yield task
@@ -228,9 +200,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
yield task
await self.load_and_merge_all_skills(
sandbox,
workspace,
task.request.selected_repository,
workspace.working_dir,
agent_server_url,
)
async def _configure_git_user_settings(
@@ -485,6 +457,7 @@ class AppConversationServiceBase(AppConversationService, ABC):
security_analyzer_str: String value from settings
httpx_client: HTTP client for making API requests
"""
if session_api_key is None:
return
@@ -32,7 +32,6 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartTask,
AppConversationStartTaskStatus,
AppConversationUpdateRequest,
PluginSpec,
)
from openhands.app_server.app_conversation.app_conversation_service import (
AppConversationService,
@@ -80,7 +79,6 @@ from openhands.experiments.experiment_manager import ExperimentManagerImpl
from openhands.integrations.provider import ProviderType
from openhands.sdk import Agent, AgentContext, LocalWorkspace
from openhands.sdk.llm import LLM
from openhands.sdk.plugin import PluginSource
from openhands.sdk.secret import LookupSecret, SecretValue, StaticSecret
from openhands.sdk.utils.paging import page_iterator
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
@@ -239,7 +237,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
working_dir=sandbox_spec.working_dir,
)
async for updated_task in self.run_setup_scripts(
task, sandbox, remote_workspace, agent_server_url
task, sandbox, remote_workspace
):
yield updated_task
@@ -256,7 +254,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
request.conversation_id,
remote_workspace=remote_workspace,
selected_repository=request.selected_repository,
plugins=request.plugins,
)
)
@@ -957,79 +954,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
return agent.model_copy(update=updates)
return agent
def _construct_initial_message_with_plugin_params(
self,
initial_message: SendMessageRequest | None,
plugins: list[PluginSpec] | None,
) -> SendMessageRequest | None:
"""Incorporate plugin parameters into the initial message if specified.
Plugin parameters are formatted and appended to the initial message so the
agent has context about the user-provided configuration values.
Args:
initial_message: The original initial message, if any
plugins: List of plugin specifications with optional parameters
Returns:
The initial message with plugin parameters incorporated, or the
original message if no plugin parameters are specified
"""
from openhands.agent_server.models import TextContent
if not plugins:
return initial_message
# Collect formatted parameters from plugins that have them
plugins_with_params = [p for p in plugins if p.parameters]
if not plugins_with_params:
return initial_message
# Format parameters, grouped by plugin if multiple
if len(plugins_with_params) == 1:
params_text = plugins_with_params[0].format_params_as_text()
plugin_params_message = (
f'\n\nPlugin Configuration Parameters:\n{params_text}'
)
else:
# Group by plugin name for clarity
formatted_plugins = []
for plugin in plugins_with_params:
params_text = plugin.format_params_as_text(indent=' ')
if params_text:
formatted_plugins.append(f'{plugin.display_name}:\n{params_text}')
plugin_params_message = (
'\n\nPlugin Configuration Parameters:\n' + '\n'.join(formatted_plugins)
)
if initial_message is None:
# Create a new message with just the plugin parameters
return SendMessageRequest(
content=[TextContent(text=plugin_params_message.strip())],
run=True,
)
# Append plugin parameters to existing message content
new_content = list(initial_message.content)
if new_content and isinstance(new_content[-1], TextContent):
# Append to the last text content
last_content = new_content[-1]
new_content[-1] = TextContent(
text=last_content.text + plugin_params_message,
cache_prompt=last_content.cache_prompt,
enable_truncation=last_content.enable_truncation,
)
else:
# Add as new text content
new_content.append(TextContent(text=plugin_params_message.strip()))
return SendMessageRequest(
role=initial_message.role,
content=new_content,
run=initial_message.run,
)
async def _finalize_conversation_request(
self,
agent: Agent,
@@ -1042,7 +966,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
remote_workspace: AsyncRemoteWorkspace | None,
selected_repository: str | None,
working_dir: str,
plugins: list[PluginSpec] | None = None,
) -> StartConversationRequest:
"""Finalize the conversation request with experiment variants and skills.
@@ -1057,7 +980,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
remote_workspace: Optional remote workspace for skills loading
selected_repository: Optional repository name
working_dir: Working directory path
plugins: Optional list of plugin specifications to load
Returns:
Complete StartConversationRequest ready for use
@@ -1084,23 +1006,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
_logger.warning(f'Failed to load skills: {e}', exc_info=True)
# Continue without skills - don't fail conversation startup
# Incorporate plugin parameters into initial message if specified
final_initial_message = self._construct_initial_message_with_plugin_params(
initial_message, plugins
)
# Convert PluginSpec list to SDK PluginSource list for agent server
sdk_plugins: list[PluginSource] | None = None
if plugins:
sdk_plugins = [
PluginSource(
source=p.source,
ref=p.ref,
repo_path=p.repo_path,
)
for p in plugins
]
# Create and return the final request
return StartConversationRequest(
conversation_id=conversation_id,
@@ -1109,9 +1014,8 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
confirmation_policy=self._select_confirmation_policy(
bool(user.confirmation_mode), user.security_analyzer
),
initial_message=final_initial_message,
initial_message=initial_message,
secrets=secrets,
plugins=sdk_plugins,
)
async def _build_start_conversation_request_for_user(
@@ -1126,7 +1030,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
conversation_id: UUID | None = None,
remote_workspace: AsyncRemoteWorkspace | None = None,
selected_repository: str | None = None,
plugins: list[PluginSpec] | None = None,
) -> StartConversationRequest:
"""Build a complete conversation request for a user.
@@ -1135,7 +1038,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
2. Configuring LLM and MCP settings
3. Creating an agent with appropriate context
4. Finalizing the request with skills and experiment variants
5. Passing plugins to the agent server for remote plugin loading
"""
user = await self.user_context.get_user_info()
workspace = LocalWorkspace(working_dir=working_dir)
@@ -1168,7 +1070,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
remote_workspace,
selected_repository,
working_dir,
plugins=plugins,
)
async def update_agent_server_conversation_title(
@@ -1223,8 +1124,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
self, conversation_id: UUID, request: AppConversationUpdateRequest
) -> AppConversation | None:
"""Update an app conversation and return it. Return None if the conversation
did not exist.
"""
did not exist."""
info = await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
@@ -1395,7 +1295,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
# Get all events for this conversation
i = 0
async for event in page_iterator(
self.event_service.search_events, conversation_id=conversation_id
self.event_service.search_events, conversation_id__eq=conversation_id
):
event_filename = f'event_{i:06d}_{event.id}.json'
event_path = os.path.join(temp_dir, event_filename)
@@ -1,37 +1,34 @@
"""Utilities for loading skills for V1 conversations.
This module provides functions to load skills from the agent-server,
which centralizes all skill loading logic. The app-server acts as a
thin proxy that:
1. Builds the org_config with authentication information
2. Builds the sandbox_config with exposed URLs
3. Calls the agent-server's /api/skills endpoint
This module provides functions to load skills from various sources:
- Global skills from OpenHands/skills/
- User skills from ~/.openhands/skills/
- Repository-level skills from the workspace
All source-specific skill loading is handled by the agent-server.
All skills are used in V1 conversations.
"""
import logging
import os
from pathlib import Path
import httpx
from pydantic import BaseModel
import openhands
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.user.user_context import UserContext
from openhands.integrations.provider import ProviderType
from openhands.integrations.service_types import AuthenticationError
from openhands.sdk.context.skills import Skill
from openhands.sdk.context.skills.trigger import KeywordTrigger, TaskTrigger
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
_logger = logging.getLogger(__name__)
class ExposedUrlConfig(BaseModel):
"""Configuration for an exposed URL in sandbox config."""
name: str
url: str
port: int
# Path to global skills directory
GLOBAL_SKILLS_DIR = os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
'skills',
)
WORK_HOSTS_SKILL = """The user has access to the following hosts for accessing a web application,
each of which has a corresponding port:"""
WORK_HOSTS_SKILL_FOOTER = """
When starting a web server, use the corresponding ports via environment variables:
@@ -48,30 +45,96 @@ app.run(host='0.0.0.0', port=int(os.environ.get('WORKER_1', 12000)))
```"""
class SandboxConfig(BaseModel):
"""Sandbox configuration for agent-server API request."""
def _find_and_load_global_skill_files(skill_dir: Path) -> list[Skill]:
"""Find and load all .md files from the global skills directory.
exposed_urls: list[ExposedUrlConfig]
Args:
skill_dir: Path to the global skills directory
Returns:
List of Skill objects loaded from the files (excluding README.md)
"""
skills = []
try:
# Find all .md files in the directory (excluding README.md)
md_files = [f for f in skill_dir.glob('*.md') if f.name.lower() != 'readme.md']
# Load skills from the found files
for file_path in md_files:
try:
skill = Skill.load(file_path, skill_dir)
skills.append(skill)
_logger.debug(f'Loaded global skill: {skill.name} from {file_path}')
except Exception as e:
_logger.warning(
f'Failed to load global skill from {file_path}: {str(e)}'
)
except Exception as e:
_logger.debug(f'Failed to find global skill files: {str(e)}')
return skills
class OrgConfig(BaseModel):
"""Organization configuration for agent-server API request."""
repository: str
provider: str
org_repo_url: str
org_name: str
def load_sandbox_skills(sandbox: SandboxInfo) -> list[Skill]:
"""Load skills specific to the sandbox, including exposed ports / urls."""
if not sandbox.exposed_urls:
return []
urls = [url for url in sandbox.exposed_urls if url.name.startswith('WORKER_')]
if not urls:
return []
content_list = [WORK_HOSTS_SKILL]
for url in urls:
content_list.append(f'* {url.url} (port {url.port})')
content_list.append(WORK_HOSTS_SKILL_FOOTER)
content = '\n'.join(content_list)
return [Skill(name='work_hosts', content=content, trigger=None)]
class SkillInfo(BaseModel):
"""Skill information from agent-server API response."""
def load_global_skills() -> list[Skill]:
"""Load global skills from OpenHands/skills/ directory.
name: str
content: str
triggers: list[str] = []
source: str | None = None
description: str | None = None
is_agentskills_format: bool = False
Returns:
List of Skill objects loaded from global skills directory.
Returns empty list if directory doesn't exist or on errors.
"""
skill_dir = Path(GLOBAL_SKILLS_DIR)
# Check if directory exists
if not skill_dir.exists():
_logger.debug(f'Global skills directory does not exist: {skill_dir}')
return []
try:
_logger.info(f'Loading global skills from {skill_dir}')
# Find and load all .md files from the directory
skills = _find_and_load_global_skill_files(skill_dir)
_logger.info(f'Loaded {len(skills)} global skills: {[s.name for s in skills]}')
return skills
except Exception as e:
_logger.warning(f'Failed to load global skills: {str(e)}')
return []
def _determine_repo_root(working_dir: str, selected_repository: str | None) -> str:
"""Determine the repository root directory.
Args:
working_dir: Base working directory path
selected_repository: Repository name (e.g., 'owner/repo') or None
Returns:
Path to the repository root directory
"""
if selected_repository:
repo_name = selected_repository.split('/')[-1]
return f'{working_dir}/{repo_name}'
return working_dir
async def _is_gitlab_repository(repo_name: str, user_context: UserContext) -> bool:
@@ -91,6 +154,8 @@ async def _is_gitlab_repository(repo_name: str, user_context: UserContext) -> bo
)
return repository.git_provider == ProviderType.GITLAB
except Exception:
# If we can't determine the provider, assume it's not GitLab
# This is a safe fallback since we'll just use the default .openhands
return False
@@ -113,33 +178,10 @@ async def _is_azure_devops_repository(
)
return repository.git_provider == ProviderType.AZURE_DEVOPS
except Exception:
# If we can't determine the provider, assume it's not Azure DevOps
return False
async def _get_provider_type(
selected_repository: str, user_context: UserContext
) -> str:
"""Determine the Git provider type for a repository.
Args:
selected_repository: Repository name (e.g., 'owner/repo')
user_context: UserContext to access provider handler
Returns:
Provider type string: 'github', 'gitlab', 'azure', or 'bitbucket'
"""
is_gitlab = await _is_gitlab_repository(selected_repository, user_context)
if is_gitlab:
return 'gitlab'
is_azure = await _is_azure_devops_repository(selected_repository, user_context)
if is_azure:
return 'azure'
# Default to github (covers github and bitbucket)
return 'github'
async def _determine_org_repo_path(
selected_repository: str, user_context: UserContext
) -> tuple[str, str]:
@@ -161,19 +203,27 @@ async def _determine_org_repo_path(
"""
repo_parts = selected_repository.split('/')
# Determine repository type
is_azure_devops = await _is_azure_devops_repository(
selected_repository, user_context
)
is_gitlab = await _is_gitlab_repository(selected_repository, user_context)
# Extract the org/user name
# Azure DevOps format: org/project/repo (3 parts) - extract org (first part)
# GitHub/GitLab/Bitbucket format: owner/repo (2 parts) - extract owner (first part)
if is_azure_devops and len(repo_parts) >= 3:
org_name = repo_parts[0]
org_name = repo_parts[0] # Get org from org/project/repo
else:
org_name = repo_parts[-2]
org_name = repo_parts[-2] # Get owner from owner/repo
# For GitLab and Azure DevOps, use openhands-config (since .openhands is not a valid repo name)
# For other providers, use .openhands
if is_gitlab:
org_openhands_repo = f'{org_name}/openhands-config'
elif is_azure_devops:
# Azure DevOps format: org/project/repo
# For org-level config, use: org/openhands-config/openhands-config
org_openhands_repo = f'{org_name}/openhands-config/openhands-config'
else:
org_openhands_repo = f'{org_name}/.openhands'
@@ -181,6 +231,227 @@ async def _determine_org_repo_path(
return org_openhands_repo, org_name
async def _read_file_from_workspace(
workspace: AsyncRemoteWorkspace, file_path: str, working_dir: str
) -> str | None:
"""Read file content from remote workspace.
Args:
workspace: AsyncRemoteWorkspace to execute commands
file_path: Path to the file to read
working_dir: Working directory for command execution
Returns:
File content as string, or None if file doesn't exist or read fails
"""
try:
result = await workspace.execute_command(
f'cat {file_path}', cwd=working_dir, timeout=10.0
)
if result.exit_code == 0 and result.stdout.strip():
return result.stdout
return None
except Exception as e:
_logger.debug(f'Failed to read file {file_path}: {str(e)}')
return None
async def _load_special_files(
workspace: AsyncRemoteWorkspace, repo_root: str, working_dir: str
) -> list[Skill]:
"""Load special skill files from repository root.
Loads: .cursorrules, agents.md, agent.md
Args:
workspace: AsyncRemoteWorkspace to execute commands
repo_root: Path to repository root directory
working_dir: Working directory for command execution
Returns:
List of Skill objects loaded from special files
"""
skills = []
special_files = ['.cursorrules', 'agents.md', 'agent.md']
for filename in special_files:
file_path = f'{repo_root}/{filename}'
content = await _read_file_from_workspace(workspace, file_path, working_dir)
if content:
try:
# Use simple string path to avoid Path filesystem operations
skill = Skill.load(path=filename, skill_dir=None, file_content=content)
skills.append(skill)
_logger.debug(f'Loaded special file skill: {skill.name}')
except Exception as e:
_logger.warning(f'Failed to create skill from {filename}: {str(e)}')
return skills
async def _find_and_load_skill_md_files(
workspace: AsyncRemoteWorkspace, skill_dir: str, working_dir: str
) -> list[Skill]:
"""Find and load all .md files from a skills directory in the workspace.
Args:
workspace: AsyncRemoteWorkspace to execute commands
skill_dir: Path to skills directory
working_dir: Working directory for command execution
Returns:
List of Skill objects loaded from the files (excluding README.md)
"""
skills = []
try:
# Find all .md files in the directory
result = await workspace.execute_command(
f"find {skill_dir} -type f -name '*.md' 2>/dev/null || true",
cwd=working_dir,
timeout=10.0,
)
if result.exit_code == 0 and result.stdout.strip():
file_paths = [
f.strip()
for f in result.stdout.strip().split('\n')
if f.strip() and 'README.md' not in f
]
# Load skills from the found files
for file_path in file_paths:
content = await _read_file_from_workspace(
workspace, file_path, working_dir
)
if content:
# Calculate relative path for skill name
rel_path = file_path.replace(f'{skill_dir}/', '')
try:
# Use simple string path to avoid Path filesystem operations
skill = Skill.load(
path=rel_path, skill_dir=None, file_content=content
)
skills.append(skill)
_logger.debug(f'Loaded repo skill: {skill.name}')
except Exception as e:
_logger.warning(
f'Failed to create skill from {rel_path}: {str(e)}'
)
except Exception as e:
_logger.debug(f'Failed to find skill files in {skill_dir}: {str(e)}')
return skills
def _merge_repo_skills_with_precedence(
special_skills: list[Skill],
skills_dir_skills: list[Skill],
microagents_dir_skills: list[Skill],
) -> list[Skill]:
"""Merge repository skills with precedence order.
Precedence (highest to lowest):
1. Special files (repo root)
2. .openhands/skills/ directory
3. .openhands/microagents/ directory (backward compatibility)
Args:
special_skills: Skills from special files in repo root
skills_dir_skills: Skills from .openhands/skills/ directory
microagents_dir_skills: Skills from .openhands/microagents/ directory
Returns:
Deduplicated list of skills with proper precedence
"""
# Use a dict to deduplicate by name, with earlier sources taking precedence
skills_by_name = {}
for skill in special_skills + skills_dir_skills + microagents_dir_skills:
# Only add if not already present (earlier sources win)
if skill.name not in skills_by_name:
skills_by_name[skill.name] = skill
return list(skills_by_name.values())
async def load_repo_skills(
workspace: AsyncRemoteWorkspace,
selected_repository: str | None,
working_dir: str,
) -> list[Skill]:
"""Load repository-level skills from the workspace.
Loads skills from:
1. Special files in repo root: .cursorrules, agents.md, agent.md
2. .md files in .openhands/skills/ directory (preferred)
3. .md files in .openhands/microagents/ directory (for backward compatibility)
Args:
workspace: AsyncRemoteWorkspace to execute commands in the sandbox
selected_repository: Repository name (e.g., 'owner/repo') or None
working_dir: Working directory path
Returns:
List of Skill objects loaded from repository.
Returns empty list on errors.
"""
try:
# Determine repository root directory
repo_root = _determine_repo_root(working_dir, selected_repository)
_logger.info(f'Loading repo skills from {repo_root}')
# Load special files from repo root
special_skills = await _load_special_files(workspace, repo_root, working_dir)
# Load .md files from .openhands/skills/ directory (preferred)
skills_dir = f'{repo_root}/.openhands/skills'
skills_dir_skills = await _find_and_load_skill_md_files(
workspace, skills_dir, working_dir
)
# Load .md files from .openhands/microagents/ directory (backward compatibility)
microagents_dir = f'{repo_root}/.openhands/microagents'
microagents_dir_skills = await _find_and_load_skill_md_files(
workspace, microagents_dir, working_dir
)
# Merge all loaded skills with proper precedence
all_skills = _merge_repo_skills_with_precedence(
special_skills, skills_dir_skills, microagents_dir_skills
)
_logger.info(
f'Loaded {len(all_skills)} repo skills: {[s.name for s in all_skills]}'
)
return all_skills
except Exception as e:
_logger.warning(f'Failed to load repo skills: {str(e)}')
return []
def _validate_repository_for_org_skills(selected_repository: str) -> bool:
"""Validate that the repository path has sufficient parts for org skills.
Args:
selected_repository: Repository name (e.g., 'owner/repo')
Returns:
True if repository is valid for org skills loading, False otherwise
"""
repo_parts = selected_repository.split('/')
if len(repo_parts) < 2:
_logger.warning(
f'Repository path has insufficient parts ({len(repo_parts)} < 2), skipping org-level skills'
)
return False
return True
async def _get_org_repository_url(
org_openhands_repo: str, user_context: UserContext
) -> str | None:
@@ -210,193 +481,224 @@ async def _get_org_repository_url(
return None
async def build_org_config(
selected_repository: str | None,
user_context: UserContext,
) -> OrgConfig | None:
"""Build organization config for agent-server API request.
async def _clone_org_repository(
workspace: AsyncRemoteWorkspace,
remote_url: str,
org_repo_dir: str,
working_dir: str,
org_openhands_repo: str,
) -> bool:
"""Clone organization repository to temporary directory.
Args:
selected_repository: Repository name (e.g., 'owner/repo') or None
user_context: UserContext to access authentication and provider info
workspace: AsyncRemoteWorkspace to execute commands
remote_url: Authenticated Git URL
org_repo_dir: Temporary directory path for cloning
working_dir: Working directory for command execution
org_openhands_repo: Organization repository path (for logging)
Returns:
org_config dict if org repository exists and is accessible, None otherwise
True if clone successful, False otherwise
"""
_logger.debug(f'Creating temporary directory for org repo: {org_repo_dir}')
# Clone the repo (shallow clone for efficiency)
clone_cmd = f'GIT_TERMINAL_PROMPT=0 git clone --depth 1 {remote_url} {org_repo_dir}'
_logger.info('Executing clone command for org-level repo')
result = await workspace.execute_command(clone_cmd, working_dir, timeout=120.0)
if result.exit_code != 0:
_logger.info(
f'No org-level skills found at {org_openhands_repo} (exit_code: {result.exit_code})'
)
_logger.debug(f'Clone command output: {result.stderr}')
return False
_logger.info(f'Successfully cloned org-level skills from {org_openhands_repo}')
return True
async def _load_skills_from_org_directories(
workspace: AsyncRemoteWorkspace, org_repo_dir: str, working_dir: str
) -> tuple[list[Skill], list[Skill]]:
"""Load skills from both skills/ and microagents/ directories in org repo.
Args:
workspace: AsyncRemoteWorkspace to execute commands
org_repo_dir: Path to cloned organization repository
working_dir: Working directory for command execution
Returns:
Tuple of (skills_dir_skills, microagents_dir_skills)
"""
skills_dir = f'{org_repo_dir}/skills'
skills_dir_skills = await _find_and_load_skill_md_files(
workspace, skills_dir, working_dir
)
microagents_dir = f'{org_repo_dir}/microagents'
microagents_dir_skills = await _find_and_load_skill_md_files(
workspace, microagents_dir, working_dir
)
return skills_dir_skills, microagents_dir_skills
def _merge_org_skills_with_precedence(
skills_dir_skills: list[Skill], microagents_dir_skills: list[Skill]
) -> list[Skill]:
"""Merge skills from skills/ and microagents/ with proper precedence.
Precedence: skills/ > microagents/ (skills/ overrides microagents/ for same name)
Args:
skills_dir_skills: Skills loaded from skills/ directory
microagents_dir_skills: Skills loaded from microagents/ directory
Returns:
Merged list of skills with proper precedence applied
"""
skills_by_name = {}
for skill in microagents_dir_skills + skills_dir_skills:
# Later sources (skills/) override earlier ones (microagents/)
if skill.name not in skills_by_name:
skills_by_name[skill.name] = skill
else:
_logger.debug(
f'Overriding org skill "{skill.name}" from microagents/ with skills/'
)
skills_by_name[skill.name] = skill
return list(skills_by_name.values())
async def _cleanup_org_repository(
workspace: AsyncRemoteWorkspace, org_repo_dir: str, working_dir: str
) -> None:
"""Clean up cloned organization repository directory.
Args:
workspace: AsyncRemoteWorkspace to execute commands
org_repo_dir: Path to cloned organization repository
working_dir: Working directory for command execution
"""
cleanup_cmd = f'rm -rf {org_repo_dir}'
await workspace.execute_command(cleanup_cmd, working_dir, timeout=10.0)
async def load_org_skills(
workspace: AsyncRemoteWorkspace,
selected_repository: str | None,
working_dir: str,
user_context: UserContext,
) -> list[Skill]:
"""Load organization-level skills from the organization repository.
For example, if the repository is github.com/acme-co/api, this will check if
github.com/acme-co/.openhands exists. If it does, it will clone it and load
the skills from both the ./skills/ and ./microagents/ folders.
For GitLab repositories, it will use openhands-config instead of .openhands
since GitLab doesn't support repository names starting with non-alphanumeric
characters.
For Azure DevOps repositories, it will use org/openhands-config/openhands-config
format to match Azure DevOps's three-part repository structure (org/project/repo).
Args:
workspace: AsyncRemoteWorkspace to execute commands in the sandbox
selected_repository: Repository name (e.g., 'owner/repo') or None
working_dir: Working directory path
user_context: UserContext to access provider handler and authentication
Returns:
List of Skill objects loaded from organization repository.
Returns empty list if no repository selected or on errors.
"""
if not selected_repository:
return None
repo_parts = selected_repository.split('/')
if len(repo_parts) < 2:
_logger.warning(
f'Repository path has insufficient parts ({len(repo_parts)} < 2), '
f'skipping org-level skills'
)
return None
return []
try:
_logger.debug(
f'Starting org-level skill loading for repository: {selected_repository}'
)
# Validate repository path
if not _validate_repository_for_org_skills(selected_repository):
return []
# Determine organization repository path
org_openhands_repo, org_name = await _determine_org_repo_path(
selected_repository, user_context
)
org_repo_url = await _get_org_repository_url(org_openhands_repo, user_context)
if not org_repo_url:
return None
_logger.info(f'Checking for org-level skills at {org_openhands_repo}')
provider = await _get_provider_type(selected_repository, user_context)
# Get authenticated URL for org repository
remote_url = await _get_org_repository_url(org_openhands_repo, user_context)
if not remote_url:
return []
return OrgConfig(
repository=selected_repository,
provider=provider,
org_repo_url=org_repo_url,
org_name=org_name,
# Clone the organization repository
org_repo_dir = f'{working_dir}/_org_openhands_{org_name}'
clone_success = await _clone_org_repository(
workspace, remote_url, org_repo_dir, working_dir, org_openhands_repo
)
if not clone_success:
return []
# Load skills from both skills/ and microagents/ directories
(
skills_dir_skills,
microagents_dir_skills,
) = await _load_skills_from_org_directories(
workspace, org_repo_dir, working_dir
)
except Exception as e:
_logger.debug(f'Failed to build org config: {str(e)}')
return None
# Merge skills with proper precedence
loaded_skills = _merge_org_skills_with_precedence(
skills_dir_skills, microagents_dir_skills
)
def build_sandbox_config(sandbox: SandboxInfo) -> SandboxConfig | None:
"""Build sandbox config for agent-server API request.
Args:
sandbox: SandboxInfo containing exposed URLs
Returns:
sandbox_config dict if there are exposed URLs, None otherwise
"""
if not sandbox.exposed_urls:
return None
exposed_urls = [
ExposedUrlConfig(name=url.name, url=url.url, port=url.port)
for url in sandbox.exposed_urls
]
return SandboxConfig(exposed_urls=exposed_urls)
async def load_skills_from_agent_server(
agent_server_url: str,
session_api_key: str | None,
project_dir: str,
org_config: OrgConfig | None = None,
sandbox_config: SandboxConfig | None = None,
load_public: bool = True,
load_user: bool = True,
load_project: bool = True,
load_org: bool = True,
) -> list[Skill]:
"""Load all skills from the agent-server.
This function makes a single API call to the agent-server's /api/skills
endpoint to load and merge skills from all configured sources.
Args:
agent_server_url: URL of the agent server (e.g., 'http://localhost:8000')
session_api_key: Session API key for authentication (optional)
project_dir: Workspace directory path for project skills
org_config: Organization skills configuration (optional)
sandbox_config: Sandbox skills configuration (optional)
load_public: Whether to load public skills (default: True)
load_user: Whether to load user skills (default: True)
load_project: Whether to load project skills (default: True)
load_org: Whether to load organization skills (default: True)
Returns:
List of Skill objects merged from all sources.
Returns empty list on error.
"""
try:
# Build request payload
payload = {
'load_public': load_public,
'load_user': load_user,
'load_project': load_project,
'load_org': load_org,
'project_dir': project_dir,
'org_config': org_config.model_dump() if org_config else None,
'sandbox_config': sandbox_config.model_dump() if sandbox_config else None,
}
# Build headers
headers = {'Content-Type': 'application/json'}
if session_api_key:
headers['X-Session-API-Key'] = session_api_key
# Make API request
async with httpx.AsyncClient() as client:
response = await client.post(
f'{agent_server_url}/api/skills',
json=payload,
headers=headers,
timeout=60.0,
)
response.raise_for_status()
data = response.json()
# Convert response to Skill objects
skills: list[Skill] = []
for skill_data_dict in data.get('skills', []):
try:
skill_info = SkillInfo.model_validate(skill_data_dict)
skill = _convert_skill_info_to_skill(skill_info)
skills.append(skill)
except Exception as e:
skill_name = (
skill_data_dict.get('name', 'unknown')
if isinstance(skill_data_dict, dict)
else 'unknown'
)
_logger.warning(f'Failed to convert skill {skill_name}: {e}')
sources = data.get('sources', {})
_logger.info(
f'Loaded {len(skills)} skills from agent-server: '
f'sources={sources}, names={[s.name for s in skills]}'
f'Loaded {len(loaded_skills)} skills from org-level repository {org_openhands_repo}: {[s.name for s in loaded_skills]}'
)
return skills
# Clean up the org repo directory
await _cleanup_org_repository(workspace, org_repo_dir, working_dir)
except httpx.HTTPStatusError as e:
_logger.warning(
f'Agent-server returned error status {e.response.status_code}: '
f'{e.response.text}'
)
return []
except httpx.RequestError as e:
_logger.warning(f'Failed to connect to agent-server: {e}')
return loaded_skills
except AuthenticationError as e:
_logger.debug(f'org-level skill directory not found: {str(e)}')
return []
except Exception as e:
_logger.warning(f'Failed to load skills from agent-server: {e}')
_logger.warning(f'Failed to load org-level skills: {str(e)}')
return []
def _convert_skill_info_to_skill(skill_info: SkillInfo) -> Skill:
"""Convert skill info from API response to Skill object.
def merge_skills(skill_lists: list[list[Skill]]) -> list[Skill]:
"""Merge multiple skill lists, avoiding duplicates by name.
Later lists take precedence over earlier lists for duplicate names.
Args:
skill_info: SkillInfo model from API response
skill_lists: List of skill lists to merge
Returns:
Skill object
Deduplicated list of skills with later lists overriding earlier ones
"""
trigger = None
skills_by_name = {}
if skill_info.triggers:
# Determine trigger type based on content
if any(t.startswith('/') for t in skill_info.triggers):
trigger = TaskTrigger(triggers=skill_info.triggers)
else:
trigger = KeywordTrigger(keywords=skill_info.triggers)
for skill_list in skill_lists:
for skill in skill_list:
if skill.name in skills_by_name:
_logger.debug(
f'Overriding skill "{skill.name}" from earlier source with later source'
)
skills_by_name[skill.name] = skill
return Skill(
name=skill_info.name,
content=skill_info.content,
trigger=trigger,
source=skill_info.source,
description=skill_info.description,
is_agentskills_format=skill_info.is_agentskills_format,
)
result = list(skills_by_name.values())
_logger.debug(f'Merged skills: {[s.name for s in result]}')
return result
@@ -541,8 +541,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
def _fix_timezone(self, value: datetime) -> datetime:
"""Sqlite does not stpre timezones - and since we can't update the existing models
we assume UTC if the timezone is missing.
"""
we assume UTC if the timezone is missing."""
if not value.tzinfo:
value = value.replace(tzinfo=UTC)
return value
@@ -68,8 +68,7 @@ class StoredAppConversationStartTask(Base): # type: ignore
class SQLAppConversationStartTaskService(AppConversationStartTaskService):
"""SQL implementation of AppConversationStartTaskService focused on db operations.
This allows storing and retrieving conversation start tasks from the database.
"""
This allows storing and retrieving conversation start tasks from the database."""
session: AsyncSession
user_id: str | None = None
+14 -1
View File
@@ -243,7 +243,20 @@ def config_from_env() -> AppServerConfig:
config.sandbox_spec = DockerSandboxSpecServiceInjector()
if config.app_conversation_info is None:
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
# Use enterprise injector if running in SAAS mode
if 'saas' in (os.getenv('OPENHANDS_CONFIG_CLS') or '').lower():
try:
# Import enterprise injector dynamically
from enterprise.server.utils.saas_app_conversation_info_injector import (
SaasAppConversationInfoServiceInjector,
)
config.app_conversation_info = SaasAppConversationInfoServiceInjector()
except ImportError:
# Fallback to OSS injector if enterprise module is not available
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
else:
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
if config.app_conversation_start_task is None:
config.app_conversation_start_task = (
@@ -13,7 +13,7 @@ from openhands.sdk.utils.models import DiscriminatedUnionMixin
# The version of the agent server to use for deployments.
# Typically this will be the same as the values from the pyproject.toml
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:c775ff6-python'
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:10fff69-python'
class SandboxSpecService(ABC):
+1 -16
View File
@@ -136,7 +136,7 @@ class LLM(RetryMixin, DebugMixin):
if self.config.model.startswith('openhands/'):
model_name = self.config.model.removeprefix('openhands/')
self.config.model = f'litellm_proxy/{model_name}'
self.config.base_url = _get_openhands_llm_base_url()
self.config.base_url = 'https://llm-proxy.app.all-hands.dev/'
logger.debug(
f'Rewrote openhands/{model_name} to {self.config.model} with base URL {self.config.base_url}'
)
@@ -851,18 +851,3 @@ class LLM(RetryMixin, DebugMixin):
# let pydantic handle the serialization
return [message.model_dump() for message in messages]
def _get_openhands_llm_base_url():
# Get the API url if specified
lite_llm_api_url = os.getenv('LITE_LLM_API_URL')
if lite_llm_api_url:
return lite_llm_api_url
# Fallback to using web_host.
web_host = os.getenv('WEB_HOST')
if web_host and ('.staging.' in web_host or web_host.startswith('staging')):
return 'https://llm-proxy.staging.all-hands.dev/'
# Use the default
return 'https://llm-proxy.app.all-hands.dev/'
+1 -5
View File
@@ -1,7 +1,5 @@
from __future__ import annotations
from typing import Annotated
from pydantic import (
BaseModel,
ConfigDict,
@@ -33,9 +31,7 @@ class Settings(BaseModel):
user_version: int | None = None
remote_runtime_resource_factor: int | None = None
# Planned to be removed from settings
secrets_store: Annotated[Secrets, Field(frozen=True)] = Field(
default_factory=Secrets
)
secrets_store: Secrets = Field(default_factory=Secrets, frozen=True)
enable_default_condenser: bool = True
enable_sound_notifications: bool = False
enable_proactive_conversation_starters: bool = True
Generated
+10 -10
View File
@@ -7731,14 +7731,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
[[package]]
name = "openhands-agent-server"
version = "1.10.0"
version = "1.8.2"
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
{file = "openhands_agent_server-1.8.2-py3-none-any.whl", hash = "sha256:e9abb2e0fe970715537d0e0fc1aea3dd64bb9e8b531f70cb72b3d4e486aaa46a"},
{file = "openhands_agent_server-1.8.2.tar.gz", hash = "sha256:43db2371ee84b100ac921396338dee74359fceeb5c9400c90530bcc5730144c3"},
]
[package.dependencies]
@@ -7755,14 +7755,14 @@ wsproto = ">=1.2.0"
[[package]]
name = "openhands-sdk"
version = "1.10.0"
version = "1.8.2"
description = "OpenHands SDK - Core functionality for building AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
{file = "openhands_sdk-1.8.2-py3-none-any.whl", hash = "sha256:b4fad9581865ce222a3e6722384e4df56113db01bd34c2d2d408dfd9695365c0"},
{file = "openhands_sdk-1.8.2.tar.gz", hash = "sha256:5bfb17c8b9515210d121249deb1f3d0dc407c3737edc55b5e73330b4571d61e3"},
]
[package.dependencies]
@@ -7783,14 +7783,14 @@ boto3 = ["boto3 (>=1.35.0)"]
[[package]]
name = "openhands-tools"
version = "1.10.0"
version = "1.8.2"
description = "OpenHands Tools - Runtime tools for AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
{file = "openhands_tools-1.8.2-py3-none-any.whl", hash = "sha256:283f0c1fdd316914559cd16ade792383715478a8f5a73f7166daffc34bf9e5af"},
{file = "openhands_tools-1.8.2.tar.gz", hash = "sha256:eae416e3867f7cb595129a33a4b9237886c4b8a075d2bc7618da55963f2747d5"},
]
[package.dependencies]
@@ -17367,4 +17367,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api
[metadata]
lock-version = "2.1"
python-versions = "^3.12,<3.14"
content-hash = "f67478db2385eb258369313ac831b26582d744294c0996a35e786c3d7ced5db1"
content-hash = "530cf5f60f1a38d69fea854eb3682a64a192f9f97beaa0dfc9dcedf239de3a58"
+9 -6
View File
@@ -54,9 +54,9 @@ dependencies = [
"numpy",
"openai==2.8",
"openhands-aci==0.3.2",
"openhands-agent-server==1.10",
"openhands-sdk==1.10",
"openhands-tools==1.10",
"openhands-agent-server==1.8.2",
"openhands-sdk==1.8.2",
"openhands-tools==1.8.2",
"opentelemetry-api>=1.33.1",
"opentelemetry-exporter-otlp-proto-grpc>=1.33.1",
"pathspec>=0.12.1",
@@ -280,9 +280,12 @@ e2b-code-interpreter = { version = "^2.0.0", optional = true }
pybase62 = "^1.0.0"
# V1 dependencies
openhands-sdk = "1.10"
openhands-agent-server = "1.10"
openhands-tools = "1.10"
#openhands-agent-server = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-agent-server", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
#openhands-sdk = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-sdk", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
#openhands-tools = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-tools", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
openhands-sdk = "1.8.2"
openhands-agent-server = "1.8.2"
openhands-tools = "1.8.2"
python-jose = { version = ">=3.3", extras = [ "cryptography" ] }
sqlalchemy = { extras = [ "asyncio" ], version = "^2.0.40" }
pg8000 = "^1.31.5"
@@ -17,7 +17,6 @@ from openhands.app_server.app_conversation.app_conversation_service_base import
)
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.user.user_context import UserContext
from openhands.sdk.context.skills import Skill
class MockUserInfo:
@@ -921,251 +920,347 @@ async def test_configure_git_user_settings_special_characters_in_name(mock_works
# =============================================================================
# Tests for load_and_merge_all_skills (updated to use agent-server)
# Tests for load_and_merge_all_skills with org skills
# =============================================================================
class TestMergeSkills:
"""Test _merge_skills method."""
class TestLoadAndMergeAllSkillsWithOrgSkills:
"""Test load_and_merge_all_skills includes organization skills."""
def test_merges_skills_with_no_duplicates(self):
"""Test merging skill lists with no duplicate names."""
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_includes_org_skills(
self,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
):
"""Test that load_and_merge_all_skills loads and merges org skills."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True, user_context=mock_user_context
init_git_in_empty_workspace=True,
user_context=mock_user_context,
)
skill1 = Mock(spec=Skill)
skill1.name = 'skill1'
skill2 = Mock(spec=Skill)
skill2.name = 'skill2'
skill3 = Mock(spec=Skill)
skill3.name = 'skill3'
sandbox = Mock(spec=SandboxInfo)
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
skill_lists = [[skill1], [skill2], [skill3]]
# Create distinct mock skills for each source
sandbox_skill = Mock()
sandbox_skill.name = 'sandbox_skill'
global_skill = Mock()
global_skill.name = 'global_skill'
user_skill = Mock()
user_skill.name = 'user_skill'
org_skill = Mock()
org_skill.name = 'org_skill'
repo_skill = Mock()
repo_skill.name = 'repo_skill'
mock_load_sandbox.return_value = [sandbox_skill]
mock_load_global.return_value = [global_skill]
mock_load_user.return_value = [user_skill]
mock_load_org.return_value = [org_skill]
mock_load_repo.return_value = [repo_skill]
# Act
result = service._merge_skills(skill_lists)
result = await service.load_and_merge_all_skills(
sandbox, remote_workspace, 'owner/repo', '/workspace'
)
# Assert
assert len(result) == 3
assert len(result) == 5
names = {s.name for s in result}
assert names == {'skill1', 'skill2', 'skill3'}
def test_merges_skills_with_duplicates_later_wins(self):
"""Test that later skill lists override earlier ones for duplicate names."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True, user_context=mock_user_context
assert names == {
'sandbox_skill',
'global_skill',
'user_skill',
'org_skill',
'repo_skill',
}
mock_load_org.assert_called_once_with(
remote_workspace, 'owner/repo', '/workspace', mock_user_context
)
skill1_v1 = Mock(spec=Skill)
skill1_v1.name = 'skill1'
skill1_v1.version = 'v1'
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_org_skills_precedence(
self,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
):
"""Test that org skills have correct precedence (higher than user, lower than repo)."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True,
user_context=mock_user_context,
)
skill1_v2 = Mock(spec=Skill)
skill1_v2.name = 'skill1'
skill1_v2.version = 'v2'
sandbox = Mock(spec=SandboxInfo)
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
skill2 = Mock(spec=Skill)
skill2.name = 'skill2'
# Create skills with same name but different sources
user_skill = Mock()
user_skill.name = 'common_skill'
user_skill.source = 'user'
skill_lists = [[skill1_v1], [skill1_v2, skill2]]
org_skill = Mock()
org_skill.name = 'common_skill'
org_skill.source = 'org'
repo_skill = Mock()
repo_skill.name = 'common_skill'
repo_skill.source = 'repo'
mock_load_sandbox.return_value = []
mock_load_global.return_value = []
mock_load_user.return_value = [user_skill]
mock_load_org.return_value = [org_skill]
mock_load_repo.return_value = [repo_skill]
# Act
result = service._merge_skills(skill_lists)
result = await service.load_and_merge_all_skills(
sandbox, remote_workspace, 'owner/repo', '/workspace'
)
# Assert
# Should have only one skill with repo source (highest precedence)
assert len(result) == 1
assert result[0].source == 'repo'
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_org_skills_override_user_skills(
self,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
):
"""Test that org skills override user skills for same name."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True,
user_context=mock_user_context,
)
sandbox = Mock(spec=SandboxInfo)
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
# Create skills with same name
user_skill = Mock()
user_skill.name = 'shared_skill'
user_skill.priority = 'low'
org_skill = Mock()
org_skill.name = 'shared_skill'
org_skill.priority = 'high'
mock_load_sandbox.return_value = []
mock_load_global.return_value = []
mock_load_user.return_value = [user_skill]
mock_load_org.return_value = [org_skill]
mock_load_repo.return_value = []
# Act
result = await service.load_and_merge_all_skills(
sandbox, remote_workspace, 'owner/repo', '/workspace'
)
# Assert
assert len(result) == 1
assert result[0].priority == 'high' # Org skill should win
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_handles_org_skills_failure(
self,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
):
"""Test that failure to load org skills doesn't break the overall process."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True,
user_context=mock_user_context,
)
sandbox = Mock(spec=SandboxInfo)
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
global_skill = Mock()
global_skill.name = 'global_skill'
repo_skill = Mock()
repo_skill.name = 'repo_skill'
mock_load_sandbox.return_value = []
mock_load_global.return_value = [global_skill]
mock_load_user.return_value = []
mock_load_org.return_value = [] # Org skills failed/empty
mock_load_repo.return_value = [repo_skill]
# Act
result = await service.load_and_merge_all_skills(
sandbox, remote_workspace, 'owner/repo', '/workspace'
)
# Assert
# Should still have skills from other sources
assert len(result) == 2
skill1_result = next(s for s in result if s.name == 'skill1')
assert skill1_result.version == 'v2'
class TestLoadAndMergeAllSkills:
"""Test load_and_merge_all_skills method (updated to use agent-server)."""
names = {s.name for s in result}
assert names == {'global_skill', 'repo_skill'}
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
)
async def test_loads_skills_successfully(
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_no_selected_repository(
self,
mock_build_sandbox_config,
mock_build_org_config,
mock_load_skills,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
):
"""Test successfully loading skills from agent-server."""
"""Test skill loading when no repository is selected."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True, user_context=mock_user_context
init_git_in_empty_workspace=True,
user_context=mock_user_context,
)
mock_workspace = AsyncMock()
mock_workspace.working_dir = '/workspace'
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
sandbox = Mock(spec=SandboxInfo)
exposed_url = ExposedUrl(
name='AGENT_SERVER', url='http://localhost:8000', port=8000
)
sandbox.exposed_urls = [exposed_url]
sandbox.session_api_key = 'test-api-key'
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
skill1 = Mock(spec=Skill)
skill1.name = 'skill1'
skill2 = Mock(spec=Skill)
skill2.name = 'skill2'
global_skill = Mock()
global_skill.name = 'global_skill'
mock_load_skills.return_value = [skill1, skill2]
mock_build_org_config.return_value = {'repository': 'owner/repo'}
mock_build_sandbox_config.return_value = {'exposed_urls': []}
mock_load_sandbox.return_value = []
mock_load_global.return_value = [global_skill]
mock_load_user.return_value = []
mock_load_org.return_value = []
mock_load_repo.return_value = []
# Act
result = await service.load_and_merge_all_skills(
sandbox, 'owner/repo', '/workspace', 'http://localhost:8000'
sandbox, remote_workspace, None, '/workspace'
)
# Assert
assert len(result) == 2
assert result[0].name == 'skill1'
assert result[1].name == 'skill2'
mock_load_skills.assert_called_once()
call_kwargs = mock_load_skills.call_args[1]
assert call_kwargs['agent_server_url'] == 'http://localhost:8000'
assert call_kwargs['session_api_key'] == 'test-api-key'
assert call_kwargs['project_dir'] == '/workspace/repo'
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
)
async def test_returns_empty_list_when_no_agent_server_url(self, mock_load_skills):
"""Test returns empty list when agent-server URL is not available."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True, user_context=mock_user_context
assert len(result) == 1
# Org skills should be called even with None repository
mock_load_org.assert_called_once_with(
remote_workspace, None, '/workspace', mock_user_context
)
AsyncMock()
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
sandbox = Mock(spec=SandboxInfo)
exposed_url = ExposedUrl(
name='VSCODE', url='http://localhost:8080', port=8080
)
sandbox.exposed_urls = [exposed_url]
# Act - pass empty string to simulate no agent server URL
# This should still call load_skills_from_agent_server but it will fail
result = await service.load_and_merge_all_skills(
sandbox, 'owner/repo', '/workspace', ''
)
# Assert - should return empty list when agent_server_url is empty
assert result == []
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
)
async def test_uses_working_dir_when_no_repository(
self,
mock_build_sandbox_config,
mock_build_org_config,
mock_load_skills,
):
"""Test uses working_dir as project_dir when no repository is selected."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True, user_context=mock_user_context
)
AsyncMock()
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
sandbox = Mock(spec=SandboxInfo)
exposed_url = ExposedUrl(
name='AGENT_SERVER', url='http://localhost:8000', port=8000
)
sandbox.exposed_urls = [exposed_url]
sandbox.session_api_key = 'test-key'
mock_load_skills.return_value = []
mock_build_org_config.return_value = None
mock_build_sandbox_config.return_value = None
# Act
await service.load_and_merge_all_skills(
sandbox, None, '/workspace', 'http://localhost:8000'
)
# Assert
call_kwargs = mock_load_skills.call_args[1]
assert call_kwargs['project_dir'] == '/workspace'
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
)
async def test_handles_exception_gracefully(
self,
mock_build_sandbox_config,
mock_build_org_config,
mock_load_skills,
):
"""Test handles exceptions during skill loading."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True, user_context=mock_user_context
)
AsyncMock()
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
sandbox = Mock(spec=SandboxInfo)
exposed_url = ExposedUrl(
name='AGENT_SERVER', url='http://localhost:8000', port=8000
)
sandbox.exposed_urls = [exposed_url]
sandbox.session_api_key = 'test-key'
mock_load_skills.side_effect = Exception('Network error')
# Act
result = await service.load_and_merge_all_skills(
sandbox, 'owner/repo', '/workspace', 'http://localhost:8000'
)
# Assert
assert result == []
@@ -1165,50 +1165,6 @@ class TestLiveStatusAppConversationService:
)
self.mock_event_service.search_events.assert_called_once()
@pytest.mark.asyncio
async def test_export_conversation_calls_search_events_with_correct_parameter_name(
self,
):
"""Test that export_conversation calls search_events with 'conversation_id' parameter, not 'conversation_id__eq'.
This test verifies the fix for a bug where page_iterator was called with
conversation_id__eq instead of conversation_id, causing a TypeError since
the search_events method expects conversation_id as its parameter name.
"""
# Arrange
conversation_id = uuid4()
# Mock conversation info
mock_conversation_info = Mock(spec=AppConversationInfo)
mock_conversation_info.id = conversation_id
mock_conversation_info.model_dump_json = Mock(return_value='{}')
self.mock_app_conversation_info_service.get_app_conversation_info = AsyncMock(
return_value=mock_conversation_info
)
# Mock empty event page to simplify test
mock_event_page = Mock()
mock_event_page.items = []
mock_event_page.next_page_id = None
self.mock_event_service.search_events = AsyncMock(return_value=mock_event_page)
# Act
await self.service.export_conversation(conversation_id)
# Assert - Verify search_events was called with 'conversation_id', not 'conversation_id__eq'
self.mock_event_service.search_events.assert_called()
call_kwargs = self.mock_event_service.search_events.call_args[1]
assert 'conversation_id' in call_kwargs, (
"search_events should be called with 'conversation_id' parameter"
)
assert 'conversation_id__eq' not in call_kwargs, (
"search_events should NOT be called with 'conversation_id__eq' parameter"
)
assert call_kwargs['conversation_id'] == conversation_id
@pytest.mark.asyncio
async def test_export_conversation_large_pagination(self):
"""Test download with multiple pages of events."""
@@ -1332,7 +1288,7 @@ class TestLiveStatusAppConversationService:
task.sandbox_id = self.mock_sandbox.id
yield task
async def mock_run_setup_scripts(task, sandbox, workspace, agent_server_url):
async def mock_run_setup_scripts(task, sandbox, workspace):
yield task
self.service._wait_for_sandbox_start = mock_wait_for_sandbox
@@ -1786,855 +1742,3 @@ class TestLiveStatusAppConversationService:
stdio_server = mcp_servers['stdio-server']
assert stdio_server['command'] == 'npx'
assert stdio_server['env'] == {'TOKEN': 'value'}
class TestPluginHandling:
"""Test cases for plugin-related functionality in LiveStatusAppConversationService."""
def setup_method(self):
"""Set up test fixtures."""
# Create mock dependencies
self.mock_user_context = Mock(spec=UserContext)
self.mock_user_auth = Mock()
self.mock_user_context.user_auth = self.mock_user_auth
self.mock_jwt_service = Mock()
self.mock_sandbox_service = Mock()
self.mock_sandbox_spec_service = Mock()
self.mock_app_conversation_info_service = Mock()
self.mock_app_conversation_start_task_service = Mock()
self.mock_event_callback_service = Mock()
self.mock_event_service = Mock()
self.mock_httpx_client = Mock()
# Create service instance
self.service = LiveStatusAppConversationService(
init_git_in_empty_workspace=True,
user_context=self.mock_user_context,
app_conversation_info_service=self.mock_app_conversation_info_service,
app_conversation_start_task_service=self.mock_app_conversation_start_task_service,
event_callback_service=self.mock_event_callback_service,
event_service=self.mock_event_service,
sandbox_service=self.mock_sandbox_service,
sandbox_spec_service=self.mock_sandbox_spec_service,
jwt_service=self.mock_jwt_service,
sandbox_startup_timeout=30,
sandbox_startup_poll_frequency=1,
httpx_client=self.mock_httpx_client,
web_url='https://test.example.com',
openhands_provider_base_url='https://provider.example.com',
access_token_hard_timeout=None,
app_mode='test',
)
# Mock user info
self.mock_user = Mock()
self.mock_user.id = 'test_user_123'
self.mock_user.llm_model = 'gpt-4'
self.mock_user.llm_base_url = 'https://api.openai.com/v1'
self.mock_user.llm_api_key = 'test_api_key'
self.mock_user.confirmation_mode = False
self.mock_user.search_api_key = None
self.mock_user.condenser_max_size = None
self.mock_user.mcp_config = None
self.mock_user.security_analyzer = None
# Mock sandbox
self.mock_sandbox = Mock(spec=SandboxInfo)
self.mock_sandbox.id = uuid4()
self.mock_sandbox.status = SandboxStatus.RUNNING
def test_construct_initial_message_with_plugin_params_no_plugins(self):
"""Test _construct_initial_message_with_plugin_params with no plugins returns original message."""
from openhands.agent_server.models import SendMessageRequest, TextContent
# Test with None initial message and None plugins
result = self.service._construct_initial_message_with_plugin_params(None, None)
assert result is None
# Test with None initial message and empty plugins list
result = self.service._construct_initial_message_with_plugin_params(None, [])
assert result is None
# Test with initial message but None plugins
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, None
)
assert result is initial_msg
def test_construct_initial_message_with_plugin_params_no_params(self):
"""Test _construct_initial_message_with_plugin_params with plugins but no parameters."""
from openhands.agent_server.models import SendMessageRequest, TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Plugin with no parameters
plugins = [PluginSpec(source='github:owner/repo')]
# Test with None initial message
result = self.service._construct_initial_message_with_plugin_params(
None, plugins
)
assert result is None
# Test with initial message
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, plugins
)
assert result is initial_msg
def test_construct_initial_message_with_plugin_params_creates_new_message(self):
"""Test _construct_initial_message_with_plugin_params creates message when no initial message."""
from openhands.agent_server.models import TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugins = [
PluginSpec(
source='github:owner/repo',
parameters={'api_key': 'test123', 'debug': True},
)
]
result = self.service._construct_initial_message_with_plugin_params(
None, plugins
)
assert result is not None
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert 'Plugin Configuration Parameters:' in result.content[0].text
assert '- api_key: test123' in result.content[0].text
assert '- debug: True' in result.content[0].text
assert result.run is True
def test_construct_initial_message_with_plugin_params_appends_to_message(self):
"""Test _construct_initial_message_with_plugin_params appends to existing message."""
from openhands.agent_server.models import SendMessageRequest, TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
initial_msg = SendMessageRequest(
content=[TextContent(text='Please analyze this codebase')],
run=False,
)
plugins = [
PluginSpec(
source='github:owner/repo',
ref='v1.0.0',
parameters={'target_dir': '/src', 'verbose': True},
)
]
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, plugins
)
assert result is not None
assert len(result.content) == 1
text = result.content[0].text
assert text.startswith('Please analyze this codebase')
assert 'Plugin Configuration Parameters:' in text
assert '- target_dir: /src' in text
assert '- verbose: True' in text
assert result.run is False
def test_construct_initial_message_with_plugin_params_preserves_role(self):
"""Test _construct_initial_message_with_plugin_params preserves message role."""
from openhands.agent_server.models import SendMessageRequest, TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
initial_msg = SendMessageRequest(
role='system',
content=[TextContent(text='System message')],
)
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, plugins
)
assert result is not None
assert result.role == 'system'
def test_construct_initial_message_with_plugin_params_empty_content(self):
"""Test _construct_initial_message_with_plugin_params handles empty content list."""
from openhands.agent_server.models import SendMessageRequest, TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
initial_msg = SendMessageRequest(content=[])
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, plugins
)
assert result is not None
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert 'Plugin Configuration Parameters:' in result.content[0].text
def test_construct_initial_message_with_multiple_plugins(self):
"""Test _construct_initial_message_with_plugin_params handles multiple plugins."""
from openhands.agent_server.models import TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugins = [
PluginSpec(
source='github:owner/plugin1',
parameters={'key1': 'value1'},
),
PluginSpec(
source='github:owner/plugin2',
parameters={'key2': 'value2'},
),
]
result = self.service._construct_initial_message_with_plugin_params(
None, plugins
)
assert result is not None
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
text = result.content[0].text
assert 'Plugin Configuration Parameters:' in text
# Multiple plugins should show grouped by plugin name
assert 'plugin1' in text
assert 'plugin2' in text
assert 'key1: value1' in text
assert 'key2: value2' in text
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_with_plugins(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request passes plugins list to StartConversationRequest."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {'test': StaticSecret(value='secret')}
plugins = [
PluginSpec(
source='github:owner/my-plugin',
ref='v1.0.0',
parameters={'api_key': 'test123'},
)
]
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=plugins,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is not None
assert len(result.plugins) == 1
assert result.plugins[0].source == 'github:owner/my-plugin'
assert result.plugins[0].ref == 'v1.0.0'
# Also verify initial message contains plugin params
assert result.initial_message is not None
assert (
'Plugin Configuration Parameters:' in result.initial_message.content[0].text
)
assert '- api_key: test123' in result.initial_message.content[0].text
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_without_plugins(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request without plugins sets plugins to None."""
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {}
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=None,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is None
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_plugin_without_ref(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request with plugin that has no ref."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {}
# Plugin without ref or parameters
plugins = [PluginSpec(source='github:owner/my-plugin')]
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=plugins,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is not None
assert len(result.plugins) == 1
assert result.plugins[0].source == 'github:owner/my-plugin'
assert result.plugins[0].ref is None
# No parameters, so initial message should be None
assert result.initial_message is None
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_plugin_with_repo_path(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request passes repo_path to PluginSource."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {}
# Plugin with repo_path (for marketplace repos containing multiple plugins)
plugins = [
PluginSpec(
source='github:owner/marketplace-repo',
ref='main',
repo_path='plugins/city-weather',
)
]
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=plugins,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is not None
assert len(result.plugins) == 1
assert result.plugins[0].source == 'github:owner/marketplace-repo'
assert result.plugins[0].ref == 'main'
assert result.plugins[0].repo_path == 'plugins/city-weather'
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_multiple_plugins(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request with multiple plugins."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {}
# Multiple plugins
plugins = [
PluginSpec(source='github:owner/security-plugin', ref='v2.0.0'),
PluginSpec(
source='github:owner/monorepo',
repo_path='plugins/logging',
),
PluginSpec(source='/local/path/to/plugin'),
]
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=plugins,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is not None
assert len(result.plugins) == 3
assert result.plugins[0].source == 'github:owner/security-plugin'
assert result.plugins[0].ref == 'v2.0.0'
assert result.plugins[1].source == 'github:owner/monorepo'
assert result.plugins[1].repo_path == 'plugins/logging'
assert result.plugins[2].source == '/local/path/to/plugin'
@pytest.mark.asyncio
async def test_build_start_conversation_request_for_user_with_plugins(self):
"""Test _build_start_conversation_request_for_user passes plugins to finalize method."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
self.mock_user_context.get_user_info.return_value = self.mock_user
self.mock_user_context.get_secrets.return_value = {}
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
self.mock_user_context.get_mcp_api_key.return_value = None
plugins = [
PluginSpec(
source='https://github.com/org/plugin.git',
ref='main',
parameters={'config_file': 'custom.yaml'},
)
]
# Mock _finalize_conversation_request to capture the call
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
self.service._finalize_conversation_request = mock_finalize
# Act
await self.service._build_start_conversation_request_for_user(
self.mock_sandbox,
None,
None,
None,
'/workspace',
plugins=plugins,
)
# Assert
mock_finalize.assert_called_once()
call_kwargs = mock_finalize.call_args.kwargs
assert call_kwargs['plugins'] == plugins
@pytest.mark.asyncio
async def test_build_start_conversation_request_for_user_without_plugins(self):
"""Test _build_start_conversation_request_for_user works without plugins."""
# Arrange
self.mock_user_context.get_user_info.return_value = self.mock_user
self.mock_user_context.get_secrets.return_value = {}
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
self.mock_user_context.get_mcp_api_key.return_value = None
# Mock _finalize_conversation_request
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
self.service._finalize_conversation_request = mock_finalize
# Act
await self.service._build_start_conversation_request_for_user(
self.mock_sandbox,
None,
None,
None,
'/workspace',
)
# Assert
mock_finalize.assert_called_once()
call_kwargs = mock_finalize.call_args.kwargs
assert call_kwargs.get('plugins') is None
class TestPluginSpecModel:
"""Test cases for the PluginSpec model."""
def test_plugin_spec_with_all_fields(self):
"""Test PluginSpec with all fields provided."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(
source='github:owner/repo',
ref='v1.0.0',
repo_path='plugins/my-plugin',
parameters={'key1': 'value1', 'key2': 123, 'key3': True},
)
assert plugin.source == 'github:owner/repo'
assert plugin.ref == 'v1.0.0'
assert plugin.repo_path == 'plugins/my-plugin'
assert plugin.parameters == {'key1': 'value1', 'key2': 123, 'key3': True}
def test_plugin_spec_with_only_source(self):
"""Test PluginSpec with only source provided."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='https://github.com/owner/repo.git')
assert plugin.source == 'https://github.com/owner/repo.git'
assert plugin.ref is None
assert plugin.repo_path is None
assert plugin.parameters is None
def test_plugin_spec_serialization(self):
"""Test PluginSpec serialization to JSON."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(
source='github:owner/repo',
ref='main',
repo_path='plugins/my-plugin',
parameters={'debug': True},
)
json_data = plugin.model_dump()
assert json_data == {
'source': 'github:owner/repo',
'ref': 'main',
'repo_path': 'plugins/my-plugin',
'parameters': {'debug': True},
}
def test_plugin_spec_deserialization(self):
"""Test PluginSpec deserialization from dict."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
data = {
'source': 'github:owner/repo',
'ref': 'v2.0.0',
'repo_path': 'plugins/weather',
'parameters': {'timeout': 30},
}
plugin = PluginSpec.model_validate(data)
assert plugin.source == 'github:owner/repo'
assert plugin.ref == 'v2.0.0'
assert plugin.repo_path == 'plugins/weather'
assert plugin.parameters == {'timeout': 30}
def test_plugin_spec_display_name_github_format(self):
"""Test display_name extracts repo name from github:owner/repo format."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='github:owner/my-plugin')
assert plugin.display_name == 'my-plugin'
def test_plugin_spec_display_name_git_url(self):
"""Test display_name extracts repo name from git URL."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='https://github.com/owner/repo.git')
assert plugin.display_name == 'repo.git'
def test_plugin_spec_display_name_local_path(self):
"""Test display_name extracts directory name from local path."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='/local/path/to/plugin')
assert plugin.display_name == 'plugin'
def test_plugin_spec_display_name_no_slash(self):
"""Test display_name returns source as-is when no slash present."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='local-plugin')
assert plugin.display_name == 'local-plugin'
def test_plugin_spec_format_params_as_text(self):
"""Test format_params_as_text formats parameters as text."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(
source='github:owner/repo',
parameters={'key1': 'value1', 'key2': 123},
)
result = plugin.format_params_as_text()
assert result == '- key1: value1\n- key2: 123'
def test_plugin_spec_format_params_as_text_with_indent(self):
"""Test format_params_as_text with custom indent."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(
source='github:owner/repo',
parameters={'debug': True},
)
result = plugin.format_params_as_text(indent=' ')
assert result == ' - debug: True'
def test_plugin_spec_format_params_as_text_no_params(self):
"""Test format_params_as_text returns None when no parameters."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='github:owner/repo')
assert plugin.format_params_as_text() is None
def test_plugin_spec_inherits_repo_path_validation(self):
"""Test PluginSpec inherits validation from SDK's PluginSource."""
import pytest
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Should reject absolute paths
with pytest.raises(ValueError, match='must be relative'):
PluginSpec(source='github:owner/repo', repo_path='/absolute/path')
# Should reject parent traversal
with pytest.raises(ValueError, match="cannot contain '..'"):
PluginSpec(source='github:owner/repo', repo_path='../parent/path')
class TestAppConversationStartRequestWithPlugins:
"""Test cases for AppConversationStartRequest with plugins field."""
def test_start_request_with_plugins(self):
"""Test AppConversationStartRequest with plugins field."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
PluginSpec,
)
plugins = [
PluginSpec(
source='github:owner/my-plugin',
ref='v1.0.0',
parameters={'api_key': 'test'},
)
]
request = AppConversationStartRequest(
title='Test conversation',
plugins=plugins,
)
assert request.plugins is not None
assert len(request.plugins) == 1
assert request.plugins[0].source == 'github:owner/my-plugin'
assert request.plugins[0].ref == 'v1.0.0'
assert request.plugins[0].parameters == {'api_key': 'test'}
def test_start_request_without_plugins(self):
"""Test AppConversationStartRequest without plugins field (backwards compatible)."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
)
request = AppConversationStartRequest(
title='Test conversation',
)
assert request.plugins is None
def test_start_request_serialization_with_plugins(self):
"""Test AppConversationStartRequest serialization includes plugins."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
PluginSpec,
)
plugins = [PluginSpec(source='github:owner/repo')]
request = AppConversationStartRequest(plugins=plugins)
json_data = request.model_dump()
assert 'plugins' in json_data
assert len(json_data['plugins']) == 1
assert json_data['plugins'][0]['source'] == 'github:owner/repo'
def test_start_request_deserialization_with_plugins(self):
"""Test AppConversationStartRequest deserialization from JSON with plugins."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
)
data = {
'title': 'Test',
'plugins': [
{
'source': 'github:owner/plugin',
'ref': 'main',
'parameters': {'key': 'value'},
},
],
}
request = AppConversationStartRequest.model_validate(data)
assert request.plugins is not None
assert len(request.plugins) == 1
assert request.plugins[0].source == 'github:owner/plugin'
assert request.plugins[0].ref == 'main'
assert request.plugins[0].parameters == {'key': 'value'}
def test_start_request_with_multiple_plugins(self):
"""Test AppConversationStartRequest with multiple plugins."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
PluginSpec,
)
plugins = [
PluginSpec(source='github:owner/plugin1', ref='v1.0.0'),
PluginSpec(source='github:owner/plugin2', repo_path='plugins/sub'),
PluginSpec(source='/local/path'),
]
request = AppConversationStartRequest(
title='Test conversation',
plugins=plugins,
)
assert request.plugins is not None
assert len(request.plugins) == 3
assert request.plugins[0].source == 'github:owner/plugin1'
assert request.plugins[1].repo_path == 'plugins/sub'
assert request.plugins[2].source == '/local/path'
File diff suppressed because it is too large Load Diff