mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 48efca5f34 |
@@ -39,7 +39,8 @@ jobs:
|
||||
run: |
|
||||
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
|
||||
json=$(jq -n -c '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" }
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" },
|
||||
{ image: "ubuntu:24.04", tag: "ubuntu" }
|
||||
]')
|
||||
else
|
||||
json=$(jq -n -c '[
|
||||
@@ -148,9 +149,6 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_TAGS }}
|
||||
platforms: ${{ env.DOCKER_PLATFORM }}
|
||||
# Caching directives to boost performance
|
||||
cache-from: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}
|
||||
cache-to: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }},mode=max
|
||||
build-args: ${{ env.DOCKER_BUILD_ARGS }}
|
||||
context: containers/runtime
|
||||
provenance: false
|
||||
|
||||
+2
-2
@@ -7,8 +7,8 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-31536c8-python}
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-docker.openhands.dev/openhands/runtime}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.2-nikolaik}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
# Organization Code Review Bot - PRD
|
||||
|
||||
## Problem Statement
|
||||
|
||||
### Engineering Leader Perspective
|
||||
|
||||
AI coding agents have made code generation cheap, but verification remains the bottleneck. Engineering leaders lack visibility into:
|
||||
- **Review quality**: How effective are AI-generated code reviews vs. human reviews?
|
||||
- **Developer engagement**: Are developers acting on AI review feedback?
|
||||
- **Organizational patterns**: What recurring feedback themes exist across the organization that could be codified?
|
||||
|
||||
Without telemetry, orgs cannot measure ROI of AI-assisted code review or systematically improve their review standards.
|
||||
|
||||
### Developer Perspective
|
||||
|
||||
Current AI code reviews produce itemized feedback, but acting on that feedback is manual and disconnected:
|
||||
- Developers must context-switch to address each review item separately
|
||||
- No one-click path from "review comment" → "fix implementation"
|
||||
- No feedback loop to help the AI learn which suggestions are valuable vs. noise
|
||||
|
||||
## Goals
|
||||
|
||||
1. **One-click remediation**: Enable developers to launch an OpenHands conversation directly from a review item to address it
|
||||
2. **Org-wide telemetry**: Track accept/dismiss rates on review items to measure review quality and surface patterns
|
||||
3. **Learned review standards**: Distill recurring org-specific feedback into a lightweight review standard the bot applies automatically
|
||||
4. **Verification signals**: Integrate code survival metrics (what % of AI-written code survives to merge) to predict review quality
|
||||
|
||||
## User Personas
|
||||
|
||||
| Persona | Description |
|
||||
|---------|-------------|
|
||||
| **Developer** | Uses the bot to get PR reviews and quickly address feedback items via one-click OpenHands sessions |
|
||||
| **Tech Lead** | Reviews org-wide feedback patterns to identify common issues and improve team coding standards |
|
||||
| **Engineering Manager** | Monitors accept/dismiss telemetry to assess AI review effectiveness and developer adoption |
|
||||
| **Platform Engineer** | Configures org-specific review rules and integrates the bot with existing CI/CD workflows |
|
||||
|
||||
## Key Use Cases
|
||||
|
||||
### 1. One-Click Review Item Remediation
|
||||
- Developer receives AI code review with itemized feedback
|
||||
- Each feedback item has a "Fix with OpenHands" button that launches a scoped conversation to address that specific issue
|
||||
- Context (diff, review comment, file) is automatically passed to the agent
|
||||
|
||||
### 2. Accept/Dismiss Feedback Telemetry
|
||||
- Developers can mark review items as "Agree & Fix" or "Dismiss"
|
||||
- Org-wide dashboard shows aggregate accept/dismiss rates per review category
|
||||
- Identifies high-value feedback patterns vs. low-signal noise
|
||||
|
||||
### 3. Org-Specific Review Standards
|
||||
- Platform engineer configures org-specific review rules (e.g., "always check for error handling in API routes")
|
||||
- Bot learns from historical code reviews to surface org-specific patterns
|
||||
- Review standards are versioned and auditable
|
||||
|
||||
### 4. Code Survival Metrics
|
||||
- Track what fraction of AI-suggested changes make it into the merged PR
|
||||
- Surface low-survival patterns to improve review prompt quality
|
||||
- Use survival signals to predict whether a review item is likely to be addressed
|
||||
|
||||
### 5. Review Quality Dashboard
|
||||
- Engineering managers see per-team and per-repo review effectiveness metrics
|
||||
- Trending view of common feedback categories over time
|
||||
- Alerts when review quality drops or dismiss rates spike
|
||||
@@ -1,205 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Downgrade script for migrated users.
|
||||
|
||||
This script identifies users who have been migrated (already_migrated=True)
|
||||
and reverts them back to the pre-migration state.
|
||||
|
||||
Usage:
|
||||
# Dry run - just list the users that would be downgraded
|
||||
python downgrade_migrated_users.py --dry-run
|
||||
|
||||
# Downgrade a specific user by their keycloak_user_id
|
||||
python downgrade_migrated_users.py --user-id <user_id>
|
||||
|
||||
# Downgrade all migrated users (with confirmation)
|
||||
python downgrade_migrated_users.py --all
|
||||
|
||||
# Downgrade all migrated users without confirmation (dangerous!)
|
||||
python downgrade_migrated_users.py --all --no-confirm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
# Add the enterprise directory to the path
|
||||
sys.path.insert(0, '/workspace/project/OpenHands/enterprise')
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from storage.database import session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
|
||||
def get_migrated_users() -> list[str]:
|
||||
"""Get list of keycloak_user_ids for users who have been migrated.
|
||||
|
||||
This includes:
|
||||
1. Users with already_migrated=True in user_settings (migrated users)
|
||||
2. Users in the 'user' table who don't have a user_settings entry (new sign-ups)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Get users from user_settings with already_migrated=True
|
||||
migrated_result = session.execute(
|
||||
select(UserSettings.keycloak_user_id).where(
|
||||
UserSettings.already_migrated.is_(True)
|
||||
)
|
||||
)
|
||||
migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]}
|
||||
|
||||
# Get users from the 'user' table (new sign-ups won't have user_settings)
|
||||
# These are users who signed up after the migration was deployed
|
||||
new_signup_result = session.execute(
|
||||
text("""
|
||||
SELECT CAST(u.id AS VARCHAR)
|
||||
FROM "user" u
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM user_settings us
|
||||
WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR)
|
||||
)
|
||||
""")
|
||||
)
|
||||
new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]}
|
||||
|
||||
# Combine both sets
|
||||
all_users = migrated_users | new_signups
|
||||
return list(all_users)
|
||||
|
||||
|
||||
async def downgrade_user(user_id: str) -> bool:
|
||||
"""Downgrade a single user.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak_user_id to downgrade
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await UserStore.downgrade_user(user_id)
|
||||
if result:
|
||||
print(f'✓ Successfully downgraded user: {user_id}')
|
||||
return True
|
||||
else:
|
||||
print(f'✗ Failed to downgrade user: {user_id}')
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f'✗ Error downgrading user {user_id}: {e}')
|
||||
logger.exception(
|
||||
'downgrade_script:error',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Downgrade migrated users back to pre-migration state'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Just list users that would be downgraded, without making changes',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--user-id',
|
||||
type=str,
|
||||
help='Downgrade a specific user by keycloak_user_id',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--all',
|
||||
action='store_true',
|
||||
help='Downgrade all migrated users',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-confirm',
|
||||
action='store_true',
|
||||
help='Skip confirmation prompt (use with caution!)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get list of migrated users
|
||||
migrated_users = get_migrated_users()
|
||||
print(f'\nFound {len(migrated_users)} migrated user(s).')
|
||||
|
||||
if args.dry_run:
|
||||
print('\n--- DRY RUN MODE ---')
|
||||
print('The following users would be downgraded:')
|
||||
for user_id in migrated_users:
|
||||
print(f' - {user_id}')
|
||||
print('\nNo changes were made.')
|
||||
return
|
||||
|
||||
if args.user_id:
|
||||
# Downgrade a specific user
|
||||
if args.user_id not in migrated_users:
|
||||
print(f'\nUser {args.user_id} is not in the migrated users list.')
|
||||
print('Either the user was not migrated, or the user_id is incorrect.')
|
||||
return
|
||||
|
||||
print(f'\nDowngrading user: {args.user_id}')
|
||||
if not args.no_confirm:
|
||||
confirm = input('Are you sure? (yes/no): ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
success = await downgrade_user(args.user_id)
|
||||
if success:
|
||||
print('\nDowngrade completed successfully.')
|
||||
else:
|
||||
print('\nDowngrade failed. Check logs for details.')
|
||||
sys.exit(1)
|
||||
|
||||
elif args.all:
|
||||
# Downgrade all migrated users
|
||||
if not migrated_users:
|
||||
print('\nNo migrated users to downgrade.')
|
||||
return
|
||||
|
||||
print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).')
|
||||
if not args.no_confirm:
|
||||
print('\nThis will:')
|
||||
print(' - Revert LiteLLM team/user budget settings')
|
||||
print(' - Delete organization entries')
|
||||
print(' - Delete user entries in the new schema')
|
||||
print(' - Reset the already_migrated flag')
|
||||
print('\nUsers to downgrade:')
|
||||
for user_id in migrated_users[:10]: # Show first 10
|
||||
print(f' - {user_id}')
|
||||
if len(migrated_users) > 10:
|
||||
print(f' ... and {len(migrated_users) - 10} more')
|
||||
|
||||
confirm = input('\nType "yes" to proceed: ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
print('\nStarting downgrade...\n')
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for user_id in migrated_users:
|
||||
success = await downgrade_user(user_id)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
print('\n--- Summary ---')
|
||||
print(f'Successful: {success_count}')
|
||||
print(f'Failed: {fail_count}')
|
||||
|
||||
if fail_count > 0:
|
||||
sys.exit(1)
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
print('\nPlease specify --dry-run, --user-id, or --all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -26,14 +26,12 @@ from integrations.utils import (
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
MissingSettingsError,
|
||||
@@ -349,7 +347,7 @@ class GithubManager(Manager):
|
||||
|
||||
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except (AuthenticationError, ExpiredError, SessionExpiredError) as e:
|
||||
except SessionExpiredError as e:
|
||||
logger.warning(
|
||||
f'[GitHub] Session expired for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
@@ -1,37 +1,18 @@
|
||||
"""Jira integration manager.
|
||||
|
||||
This module orchestrates the processing of Jira webhook events:
|
||||
1. Parse webhook payload (via JiraPayloadParser)
|
||||
2. Validate workspace
|
||||
3. Authenticate user
|
||||
4. Create view with repository selection (via JiraFactory)
|
||||
5. Start conversation job
|
||||
|
||||
The manager delegates payload parsing to JiraPayloadParser and view creation
|
||||
to JiraFactory, keeping the orchestration logic clean and traceable.
|
||||
"""
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from integrations.jira.jira_payload import (
|
||||
JiraPayloadError,
|
||||
JiraPayloadParser,
|
||||
JiraPayloadSkipped,
|
||||
JiraPayloadSuccess,
|
||||
JiraWebhookPayload,
|
||||
from integrations.jira.jira_types import JiraViewInterface
|
||||
from integrations.jira.jira_view import (
|
||||
JiraFactory,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
from integrations.jira.jira_types import (
|
||||
JiraViewInterface,
|
||||
RepositoryNotFoundError,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.jira.jira_view import JiraFactory, JiraNewConversationView
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import Message
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.utils import (
|
||||
HOST,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
get_oh_labels,
|
||||
filter_potential_repos_by_user_msg,
|
||||
get_session_expired_message,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
@@ -43,6 +24,9 @@ from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
MissingSettingsError,
|
||||
@@ -53,211 +37,267 @@ from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
# Get OH labels for this environment
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
class JiraManager(Manager):
|
||||
"""Manager for processing Jira webhook events.
|
||||
|
||||
This class orchestrates the flow from webhook receipt to conversation creation,
|
||||
delegating parsing to JiraPayloadParser and view creation to JiraFactory.
|
||||
"""
|
||||
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = JiraIntegrationStore.get_instance()
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira')
|
||||
)
|
||||
self.payload_parser = JiraPayloadParser(
|
||||
oh_label=OH_LABEL,
|
||||
inline_oh_label=INLINE_OH_LABEL,
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message.
|
||||
|
||||
Flow:
|
||||
1. Parse webhook payload
|
||||
2. Validate workspace exists and is active
|
||||
3. Authenticate user
|
||||
4. Create view (includes fetching issue details and selecting repository)
|
||||
5. Start job
|
||||
|
||||
Each step has clear logging for traceability.
|
||||
"""
|
||||
raw_payload = message.message.get('payload', {})
|
||||
|
||||
# Step 1: Parse webhook payload
|
||||
logger.info(
|
||||
'[Jira] Received webhook',
|
||||
extra={'raw_payload': raw_payload},
|
||||
)
|
||||
|
||||
parse_result = self.payload_parser.parse(raw_payload)
|
||||
|
||||
if isinstance(parse_result, JiraPayloadSkipped):
|
||||
logger.info(
|
||||
'[Jira] Webhook skipped', extra={'reason': parse_result.skip_reason}
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(parse_result, JiraPayloadError):
|
||||
logger.warning(
|
||||
'[Jira] Webhook parse failed', extra={'error': parse_result.error}
|
||||
)
|
||||
return
|
||||
|
||||
payload = parse_result.payload
|
||||
logger.info(
|
||||
'[Jira] Processing webhook',
|
||||
extra={
|
||||
'event_type': payload.event_type.value,
|
||||
'issue_key': payload.issue_key,
|
||||
'user_email': payload.user_email,
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: Validate workspace
|
||||
workspace = await self._get_active_workspace(payload)
|
||||
if not workspace:
|
||||
return
|
||||
|
||||
# Step 3: Authenticate user
|
||||
jira_user, saas_user_auth = await self._authenticate_user(payload, workspace)
|
||||
if not jira_user or not saas_user_auth:
|
||||
return
|
||||
|
||||
# Step 4: Create view (includes issue details fetch and repo selection)
|
||||
decrypted_api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
|
||||
try:
|
||||
view = await JiraFactory.create_view(
|
||||
payload=payload,
|
||||
workspace=workspace,
|
||||
user=jira_user,
|
||||
user_auth=saas_user_auth,
|
||||
decrypted_api_key=decrypted_api_key,
|
||||
)
|
||||
except RepositoryNotFoundError as e:
|
||||
logger.warning(
|
||||
'[Jira] Repository not found',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
await self._send_error_from_payload(payload, workspace, str(e))
|
||||
return
|
||||
except StartingConvoException as e:
|
||||
logger.warning(
|
||||
'[Jira] View creation failed',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
await self._send_error_from_payload(payload, workspace, str(e))
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Unexpected error creating view',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload,
|
||||
workspace,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
)
|
||||
return
|
||||
|
||||
# Step 5: Start job
|
||||
await self.start_job(view)
|
||||
|
||||
async def _get_active_workspace(
|
||||
self, payload: JiraWebhookPayload
|
||||
) -> JiraWorkspace | None:
|
||||
"""Validate and return the workspace for the webhook.
|
||||
|
||||
Returns None if:
|
||||
- Workspace not found
|
||||
- Workspace is inactive
|
||||
- Request is from service account (to prevent recursion)
|
||||
"""
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
payload.workspace_name
|
||||
)
|
||||
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
'[Jira] Workspace not found',
|
||||
extra={'workspace_name': payload.workspace_name},
|
||||
)
|
||||
# Can't send error without workspace credentials
|
||||
return None
|
||||
|
||||
# Prevent recursive triggers from service account
|
||||
if payload.user_email == workspace.svc_acc_email:
|
||||
logger.debug(
|
||||
'[Jira] Ignoring service account trigger',
|
||||
extra={'workspace_name': payload.workspace_name},
|
||||
)
|
||||
return None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(
|
||||
'[Jira] Workspace inactive',
|
||||
extra={'workspace_id': workspace.id, 'status': workspace.status},
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload, workspace, 'Jira integration is not active for your workspace.'
|
||||
)
|
||||
return None
|
||||
|
||||
return workspace
|
||||
|
||||
async def _authenticate_user(
|
||||
self, payload: JiraWebhookPayload, workspace: JiraWorkspace
|
||||
async def authenticate_user(
|
||||
self, jira_user_id: str, workspace_id: int
|
||||
) -> tuple[JiraUser | None, UserAuth | None]:
|
||||
"""Authenticate the Jira user and get OpenHands auth."""
|
||||
"""Authenticate Jira user and get their OpenHands user auth."""
|
||||
|
||||
# Find active Jira user by Keycloak user ID and workspace ID
|
||||
jira_user = await self.integration_store.get_active_user(
|
||||
payload.account_id, workspace.id
|
||||
jira_user_id, workspace_id
|
||||
)
|
||||
|
||||
if not jira_user:
|
||||
logger.warning(
|
||||
'[Jira] User not found or inactive',
|
||||
extra={
|
||||
'account_id': payload.account_id,
|
||||
'user_email': payload.user_email,
|
||||
'workspace_id': workspace.id,
|
||||
},
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload,
|
||||
workspace,
|
||||
f'User {payload.user_email} is not authenticated or active in the Jira integration.',
|
||||
f'[Jira] No active Jira user found for {jira_user_id} in workspace {workspace_id}'
|
||||
)
|
||||
return None, None
|
||||
|
||||
saas_user_auth = await get_user_auth_from_keycloak_id(
|
||||
jira_user.keycloak_user_id
|
||||
)
|
||||
|
||||
if not saas_user_auth:
|
||||
logger.warning(
|
||||
'[Jira] Failed to get OpenHands auth',
|
||||
extra={
|
||||
'keycloak_user_id': jira_user.keycloak_user_id,
|
||||
'user_email': payload.user_email,
|
||||
},
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload,
|
||||
workspace,
|
||||
f'User {payload.user_email} is not authenticated with OpenHands.',
|
||||
)
|
||||
return None, None
|
||||
|
||||
return jira_user, saas_user_auth
|
||||
|
||||
async def start_job(self, view: JiraViewInterface):
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
"""Get repositories that the user has access to."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
)
|
||||
return repos
|
||||
|
||||
def get_workspace_name_from_payload(self, payload: dict) -> str | None:
|
||||
"""Extract workspace name from Jira webhook payload."""
|
||||
if payload.get('webhookEvent') == 'comment_created':
|
||||
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
|
||||
elif payload.get('webhookEvent') == 'jira:issue_updated':
|
||||
selfUrl = payload.get('user', {}).get('self')
|
||||
else:
|
||||
return None
|
||||
|
||||
if not selfUrl:
|
||||
return None
|
||||
|
||||
parsedUrl = urlparse(selfUrl)
|
||||
return parsedUrl.hostname or None
|
||||
|
||||
def parse_webhook(self, message: Message) -> JobContext | None:
|
||||
payload = message.message.get('payload', {})
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
self_url = issue_data.get('self', '')
|
||||
if not self_url:
|
||||
logger.warning('[Jira] Missing self URL in issue data')
|
||||
base_api_url = ''
|
||||
elif '/rest/' in self_url:
|
||||
base_api_url = self_url.split('/rest/')[0]
|
||||
else:
|
||||
# Fallback: extract base URL using urlparse
|
||||
parsed = urlparse(self_url)
|
||||
base_api_url = f'{parsed.scheme}://{parsed.netloc}'
|
||||
|
||||
comment = ''
|
||||
if JiraFactory.is_ticket_comment(message):
|
||||
comment_data = payload.get('comment', {})
|
||||
comment = comment_data.get('body', '')
|
||||
user_data: dict = comment_data.get('author', {})
|
||||
elif JiraFactory.is_labeled_ticket(message):
|
||||
user_data = payload.get('user', {})
|
||||
|
||||
else:
|
||||
raise ValueError('Unrecognized jira event')
|
||||
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
|
||||
workspace_name = ''
|
||||
parsedUrl = urlparse(base_api_url)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not all(
|
||||
[
|
||||
issue_id,
|
||||
issue_key,
|
||||
user_email,
|
||||
display_name,
|
||||
account_id,
|
||||
workspace_name,
|
||||
base_api_url,
|
||||
]
|
||||
):
|
||||
return None
|
||||
|
||||
return JobContext(
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_msg=comment,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
platform_user_id=account_id,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
)
|
||||
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
return JiraFactory.is_labeled_ticket(message) or JiraFactory.is_ticket_comment(
|
||||
message
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
logger.info('[Jira]: received payload', extra={'payload': payload})
|
||||
|
||||
is_job_requested = await self.is_job_requested(message)
|
||||
if not is_job_requested:
|
||||
return
|
||||
|
||||
job_context = self.parse_webhook(message)
|
||||
|
||||
if not job_context:
|
||||
logger.info(
|
||||
'[Jira] Failed to parse webhook payload - missing required fields or invalid structure',
|
||||
extra={'event_type': payload.get('webhookEvent')},
|
||||
)
|
||||
return
|
||||
|
||||
# Get workspace by user email domain
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
job_context.workspace_name
|
||||
)
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
f'[Jira] No workspace found for email domain: {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Your workspace is not configured with Jira integration.',
|
||||
None,
|
||||
)
|
||||
return
|
||||
|
||||
# Prevent any recursive triggers from the service account
|
||||
if job_context.user_email == workspace.svc_acc_email:
|
||||
return
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Jira integration is not active for your workspace.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Authenticate user
|
||||
jira_user, saas_user_auth = await self.authenticate_user(
|
||||
job_context.platform_user_id, workspace.id
|
||||
)
|
||||
if not jira_user or not saas_user_auth:
|
||||
logger.warning(
|
||||
f'[Jira] User authentication failed for {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
f'User {job_context.user_email} is not authenticated or active in the Jira integration.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Get issue details
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
issue_title, issue_description = await self.get_issue_details(
|
||||
job_context, workspace.jira_cloud_id, workspace.svc_acc_email, api_key
|
||||
)
|
||||
job_context.issue_title = issue_title
|
||||
job_context.issue_description = issue_description
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to get issue context: {str(e)}')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to retrieve issue details. Please check the issue key and try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create Jira view
|
||||
jira_view = await JiraFactory.create_jira_view_from_payload(
|
||||
job_context,
|
||||
saas_user_auth,
|
||||
jira_user,
|
||||
workspace,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to create jira view: {str(e)}', exc_info=True)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_repository_specified(message, jira_view):
|
||||
return
|
||||
|
||||
await self.start_job(jira_view)
|
||||
|
||||
async def is_repository_specified(
|
||||
self, message: Message, jira_view: JiraViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Get user repositories
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
jira_view.saas_user_auth
|
||||
)
|
||||
|
||||
target_str = f'{jira_view.job_context.issue_description}\n{jira_view.job_context.user_msg}'
|
||||
|
||||
# Try to infer repository from issue description
|
||||
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
|
||||
|
||||
if match:
|
||||
# Found exact repository match
|
||||
jira_view.selected_repo = repos[0].full_name
|
||||
logger.info(f'[Jira] Inferred repository: {repos[0].full_name}')
|
||||
return True
|
||||
else:
|
||||
# No clear match - send repository selection comment
|
||||
await self._send_repo_selection_comment(jira_view)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Error determining repository: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, jira_view: JiraViewInterface):
|
||||
"""Start a Jira job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_callback_processor import (
|
||||
@@ -265,79 +305,101 @@ class JiraManager(Manager):
|
||||
)
|
||||
|
||||
try:
|
||||
user_info: JiraUser = jira_view.jira_user
|
||||
logger.info(
|
||||
'[Jira] Starting job',
|
||||
extra={
|
||||
'issue_key': view.payload.issue_key,
|
||||
'user_id': view.jira_user.keycloak_user_id,
|
||||
'selected_repo': view.selected_repo,
|
||||
},
|
||||
f'[Jira] Starting job for user {user_info.keycloak_user_id} '
|
||||
f'issue {jira_view.job_context.issue_key}',
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation_id = await view.create_or_update_conversation(self.jinja_env)
|
||||
conversation_id = await jira_view.create_or_update_conversation(
|
||||
self.jinja_env
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Conversation created',
|
||||
extra={
|
||||
'conversation_id': conversation_id,
|
||||
'issue_key': view.payload.issue_key,
|
||||
},
|
||||
f'[Jira] Created/Updated conversation {conversation_id} for issue {jira_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
# Register callback processor for updates
|
||||
if isinstance(view, JiraNewConversationView):
|
||||
if isinstance(jira_view, JiraNewConversationView):
|
||||
processor = JiraCallbackProcessor(
|
||||
issue_key=view.payload.issue_key,
|
||||
workspace_name=view.jira_workspace.name,
|
||||
)
|
||||
register_callback_processor(conversation_id, processor)
|
||||
logger.info(
|
||||
'[Jira] Callback processor registered',
|
||||
extra={'conversation_id': conversation_id},
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
workspace_name=jira_view.jira_workspace.name,
|
||||
)
|
||||
|
||||
# Send success response
|
||||
msg_info = view.get_response_msg()
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send initial response
|
||||
msg_info = jira_view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(
|
||||
'[Jira] Missing settings error',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
logger.warning(f'[Jira] Missing settings error: {str(e)}')
|
||||
msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(
|
||||
'[Jira] LLM authentication error',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
logger.warning(f'[Jira] LLM authentication error: {str(e)}')
|
||||
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except SessionExpiredError as e:
|
||||
logger.warning(
|
||||
'[Jira] Session expired',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
logger.warning(f'[Jira] Session expired: {str(e)}')
|
||||
msg_info = get_session_expired_message()
|
||||
|
||||
except StartingConvoException as e:
|
||||
logger.warning(
|
||||
'[Jira] Conversation start failed',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
msg_info = str(e)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Unexpected error starting job',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
f'[Jira] Unexpected error starting job: {str(e)}', exc_info=True
|
||||
)
|
||||
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
|
||||
|
||||
# Send response comment
|
||||
await self._send_comment(view, msg_info)
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to send response message: {str(e)}')
|
||||
|
||||
async def get_issue_details(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
) -> Tuple[str, str]:
|
||||
url = f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{job_context.issue_key}'
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.get(url, auth=(svc_acc_email, svc_acc_api_key))
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise ValueError(f'Issue with key {job_context.issue_key} not found.')
|
||||
|
||||
title = issue_payload.get('fields', {}).get('summary', '')
|
||||
description = issue_payload.get('fields', {}).get('description', '')
|
||||
|
||||
if not title:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
if not description:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a description.'
|
||||
)
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
@@ -347,7 +409,6 @@ class JiraManager(Manager):
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
):
|
||||
"""Send a comment to a Jira issue."""
|
||||
url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
@@ -359,53 +420,54 @@ class JiraManager(Manager):
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _send_comment(self, view: JiraViewInterface, msg: str):
|
||||
"""Send a comment using credentials from the view."""
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg),
|
||||
issue_key=view.payload.issue_key,
|
||||
jira_cloud_id=view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Failed to send comment',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
|
||||
async def _send_error_from_payload(
|
||||
async def _send_error_comment(
|
||||
self,
|
||||
payload: JiraWebhookPayload,
|
||||
workspace: JiraWorkspace,
|
||||
job_context: JobContext,
|
||||
error_msg: str,
|
||||
workspace: JiraWorkspace | None,
|
||||
):
|
||||
"""Send error comment before view is created (using payload directly)."""
|
||||
"""Send error comment to Jira issue."""
|
||||
if not workspace:
|
||||
logger.error('[Jira] Cannot send error comment - no workspace available')
|
||||
return
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
issue_key=payload.issue_key,
|
||||
issue_key=job_context.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Failed to send error comment',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
logger.error(f'[Jira] Failed to send error comment: {str(e)}')
|
||||
|
||||
async def _send_repo_selection_comment(self, jira_view: JiraViewInterface):
|
||||
"""Send a comment with repository options for the user to choose."""
|
||||
try:
|
||||
comment_msg = (
|
||||
'I need to know which repository to work with. '
|
||||
'Please add it to your issue description or send a followup comment.'
|
||||
)
|
||||
|
||||
def get_workspace_name_from_payload(self, payload: dict) -> str | None:
|
||||
"""Extract workspace name from Jira webhook payload.
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
This method is used by the route for signature verification.
|
||||
"""
|
||||
parse_result = self.payload_parser.parse(payload)
|
||||
if isinstance(parse_result, JiraPayloadSuccess):
|
||||
return parse_result.payload.workspace_name
|
||||
return None
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Sent repository selection comment for issue {jira_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to send repository selection comment: {str(e)}'
|
||||
)
|
||||
|
||||
@@ -1,267 +0,0 @@
|
||||
"""Centralized payload parsing for Jira webhooks.
|
||||
|
||||
This module provides a single source of truth for parsing and validating
|
||||
Jira webhook payloads, replacing scattered parsing logic throughout the codebase.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class JiraEventType(Enum):
|
||||
"""Types of Jira events we handle."""
|
||||
|
||||
LABELED_TICKET = 'labeled_ticket'
|
||||
COMMENT_MENTION = 'comment_mention'
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraWebhookPayload:
|
||||
"""Normalized, validated representation of a Jira webhook payload.
|
||||
|
||||
This immutable dataclass replaces JobContext and provides a single
|
||||
source of truth for all webhook data. All parsing happens in
|
||||
JiraPayloadParser, ensuring consistent validation.
|
||||
"""
|
||||
|
||||
event_type: JiraEventType
|
||||
raw_event: str # Original webhookEvent value
|
||||
|
||||
# Issue data
|
||||
issue_id: str
|
||||
issue_key: str
|
||||
|
||||
# User data
|
||||
user_email: str
|
||||
display_name: str
|
||||
account_id: str
|
||||
|
||||
# Workspace data (derived from issue self URL)
|
||||
workspace_name: str
|
||||
base_api_url: str
|
||||
|
||||
# Event-specific data
|
||||
comment_body: str = '' # For comment events
|
||||
|
||||
@property
|
||||
def user_msg(self) -> str:
|
||||
"""Alias for comment_body for backward compatibility."""
|
||||
return self.comment_body
|
||||
|
||||
|
||||
class JiraPayloadParseError(Exception):
|
||||
"""Raised when payload parsing fails."""
|
||||
|
||||
def __init__(self, reason: str, event_type: str | None = None):
|
||||
self.reason = reason
|
||||
self.event_type = event_type
|
||||
super().__init__(reason)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraPayloadSuccess:
|
||||
"""Result when parsing succeeds."""
|
||||
|
||||
payload: JiraWebhookPayload
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraPayloadSkipped:
|
||||
"""Result when event is intentionally skipped."""
|
||||
|
||||
skip_reason: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraPayloadError:
|
||||
"""Result when parsing fails due to invalid data."""
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
JiraPayloadParseResult = JiraPayloadSuccess | JiraPayloadSkipped | JiraPayloadError
|
||||
|
||||
|
||||
class JiraPayloadParser:
|
||||
"""Centralized parser for Jira webhook payloads.
|
||||
|
||||
This class provides a single entry point for parsing webhooks,
|
||||
determining event types, and extracting all necessary fields.
|
||||
Replaces scattered parsing in JiraFactory and JiraManager.
|
||||
"""
|
||||
|
||||
def __init__(self, oh_label: str, inline_oh_label: str):
|
||||
"""Initialize parser with OpenHands label configuration.
|
||||
|
||||
Args:
|
||||
oh_label: Label that triggers OpenHands (e.g., 'openhands')
|
||||
inline_oh_label: Mention that triggers OpenHands (e.g., '@openhands')
|
||||
"""
|
||||
self.oh_label = oh_label
|
||||
self.inline_oh_label = inline_oh_label
|
||||
|
||||
def parse(self, raw_payload: dict) -> JiraPayloadParseResult:
|
||||
"""Parse a raw webhook payload into a normalized JiraWebhookPayload.
|
||||
|
||||
Args:
|
||||
raw_payload: The raw webhook payload dict from Jira
|
||||
|
||||
Returns:
|
||||
One of:
|
||||
- JiraPayloadSuccess: Valid, actionable event with payload
|
||||
- JiraPayloadSkipped: Event we intentionally don't process
|
||||
- JiraPayloadError: Malformed payload we expected to process
|
||||
"""
|
||||
webhook_event = raw_payload.get('webhookEvent', '')
|
||||
|
||||
logger.debug(
|
||||
'[Jira] Parsing webhook payload', extra={'webhook_event': webhook_event}
|
||||
)
|
||||
|
||||
if webhook_event == 'jira:issue_updated':
|
||||
return self._parse_label_event(raw_payload, webhook_event)
|
||||
elif webhook_event == 'comment_created':
|
||||
return self._parse_comment_event(raw_payload, webhook_event)
|
||||
else:
|
||||
return JiraPayloadSkipped(f'Unhandled webhook event type: {webhook_event}')
|
||||
|
||||
def _parse_label_event(
|
||||
self, payload: dict, webhook_event: str
|
||||
) -> JiraPayloadParseResult:
|
||||
"""Parse an issue_updated event for label changes."""
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
|
||||
# Extract labels that were added
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if self.oh_label not in labels:
|
||||
return JiraPayloadSkipped(
|
||||
f"Label event does not contain '{self.oh_label}' label"
|
||||
)
|
||||
|
||||
# For label events, user data comes from 'user' field
|
||||
user_data = payload.get('user', {})
|
||||
return self._extract_and_validate(
|
||||
payload=payload,
|
||||
user_data=user_data,
|
||||
event_type=JiraEventType.LABELED_TICKET,
|
||||
webhook_event=webhook_event,
|
||||
comment_body='',
|
||||
)
|
||||
|
||||
def _parse_comment_event(
|
||||
self, payload: dict, webhook_event: str
|
||||
) -> JiraPayloadParseResult:
|
||||
"""Parse a comment_created event."""
|
||||
comment_data = payload.get('comment', {})
|
||||
comment_body = comment_data.get('body', '')
|
||||
|
||||
if not self._has_mention(comment_body):
|
||||
return JiraPayloadSkipped(
|
||||
f"Comment does not mention '{self.inline_oh_label}'"
|
||||
)
|
||||
|
||||
# For comment events, user data comes from 'comment.author'
|
||||
user_data = comment_data.get('author', {})
|
||||
return self._extract_and_validate(
|
||||
payload=payload,
|
||||
user_data=user_data,
|
||||
event_type=JiraEventType.COMMENT_MENTION,
|
||||
webhook_event=webhook_event,
|
||||
comment_body=comment_body,
|
||||
)
|
||||
|
||||
def _has_mention(self, text: str) -> bool:
|
||||
"""Check if text contains an exact mention of OpenHands."""
|
||||
from integrations.utils import has_exact_mention
|
||||
|
||||
return has_exact_mention(text, self.inline_oh_label)
|
||||
|
||||
def _extract_and_validate(
|
||||
self,
|
||||
payload: dict,
|
||||
user_data: dict,
|
||||
event_type: JiraEventType,
|
||||
webhook_event: str,
|
||||
comment_body: str,
|
||||
) -> JiraPayloadParseResult:
|
||||
"""Extract common fields and validate required data is present."""
|
||||
issue_data = payload.get('issue', {})
|
||||
|
||||
# Extract all fields with empty string defaults (makes them str type)
|
||||
issue_id = issue_data.get('id', '')
|
||||
issue_key = issue_data.get('key', '')
|
||||
user_email = user_data.get('emailAddress', '')
|
||||
display_name = user_data.get('displayName', '')
|
||||
account_id = user_data.get('accountId', '')
|
||||
base_api_url, workspace_name = self._extract_workspace_from_url(
|
||||
issue_data.get('self', '')
|
||||
)
|
||||
|
||||
# Validate required fields
|
||||
missing: list[str] = []
|
||||
if not issue_id:
|
||||
missing.append('issue.id')
|
||||
if not issue_key:
|
||||
missing.append('issue.key')
|
||||
if not user_email:
|
||||
missing.append('user.emailAddress')
|
||||
if not display_name:
|
||||
missing.append('user.displayName')
|
||||
if not account_id:
|
||||
missing.append('user.accountId')
|
||||
if not workspace_name:
|
||||
missing.append('workspace_name (derived from issue.self)')
|
||||
if not base_api_url:
|
||||
missing.append('base_api_url (derived from issue.self)')
|
||||
|
||||
if missing:
|
||||
return JiraPayloadError(f"Missing required fields: {', '.join(missing)}")
|
||||
|
||||
return JiraPayloadSuccess(
|
||||
JiraWebhookPayload(
|
||||
event_type=event_type,
|
||||
raw_event=webhook_event,
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
account_id=account_id,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
comment_body=comment_body,
|
||||
)
|
||||
)
|
||||
|
||||
def _extract_workspace_from_url(self, self_url: str) -> tuple[str, str]:
|
||||
"""Extract base API URL and workspace name from issue self URL.
|
||||
|
||||
Args:
|
||||
self_url: The 'self' URL from the issue data
|
||||
|
||||
Returns:
|
||||
Tuple of (base_api_url, workspace_name)
|
||||
"""
|
||||
if not self_url:
|
||||
return '', ''
|
||||
|
||||
# Extract base URL (everything before /rest/)
|
||||
if '/rest/' in self_url:
|
||||
base_api_url = self_url.split('/rest/')[0]
|
||||
else:
|
||||
parsed = urlparse(self_url)
|
||||
base_api_url = f'{parsed.scheme}://{parsed.netloc}'
|
||||
|
||||
# Extract workspace name (hostname)
|
||||
parsed = urlparse(base_api_url)
|
||||
workspace_name = parsed.hostname or ''
|
||||
|
||||
return base_api_url, workspace_name
|
||||
@@ -1,42 +1,26 @@
|
||||
"""Type definitions and interfaces for Jira integration."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import Environment
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from integrations.jira.jira_payload import JiraWebhookPayload
|
||||
|
||||
|
||||
class JiraViewInterface(ABC):
|
||||
"""Interface for Jira views that handle different types of Jira interactions.
|
||||
"""Interface for Jira views that handle different types of Jira interactions."""
|
||||
|
||||
Views hold the webhook payload directly rather than duplicating fields,
|
||||
and fetch issue details lazily when needed.
|
||||
"""
|
||||
|
||||
# Core data - view holds these references
|
||||
payload: 'JiraWebhookPayload'
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
|
||||
# Mutable state set during processing
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
async def get_issue_details(self) -> tuple[str, str]:
|
||||
"""Fetch and cache issue title and description from Jira API.
|
||||
|
||||
Returns:
|
||||
Tuple of (issue_title, issue_description)
|
||||
"""
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -51,21 +35,6 @@ class JiraViewInterface(ABC):
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""Exception raised when starting a conversation fails.
|
||||
|
||||
This provides user-friendly error messages that can be sent back to Jira.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RepositoryNotFoundError(Exception):
|
||||
"""Raised when a repository cannot be determined from the issue.
|
||||
|
||||
This is a separate error domain from StartingConvoException - it represents
|
||||
a precondition failure (no repo configured/found) rather than a conversation
|
||||
creation failure. The manager catches this and converts it to a user-friendly
|
||||
message.
|
||||
"""
|
||||
"""Exception raised when starting a conversation fails."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,21 +1,8 @@
|
||||
"""Jira view implementations and factory.
|
||||
from dataclasses import dataclass
|
||||
|
||||
Views are responsible for:
|
||||
- Holding the webhook payload and auth context
|
||||
- Lazy-loading issue details from Jira API when needed
|
||||
- Creating conversations with the selected repository
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
from integrations.jira.jira_payload import JiraWebhookPayload
|
||||
from integrations.jira.jira_types import (
|
||||
JiraViewInterface,
|
||||
RepositoryNotFoundError,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.utils import CONVERSATION_URL, infer_repo_from_message
|
||||
from integrations.jira.jira_types import JiraViewInterface, StartingConvoException
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.utils import CONVERSATION_URL, HOST, get_oh_labels, has_exact_mention
|
||||
from jinja2 import Environment
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
@@ -23,147 +10,52 @@ from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
integration_store = JiraIntegrationStore.get_instance()
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraNewConversationView(JiraViewInterface):
|
||||
"""View for creating a new Jira conversation.
|
||||
|
||||
This view holds the webhook payload directly and lazily fetches
|
||||
issue details when needed for rendering templates.
|
||||
"""
|
||||
|
||||
payload: JiraWebhookPayload
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None = None
|
||||
conversation_id: str = ''
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
# Lazy-loaded issue details (cached after first fetch)
|
||||
_issue_title: str | None = field(default=None, repr=False)
|
||||
_issue_description: str | None = field(default=None, repr=False)
|
||||
|
||||
# Decrypted API key (set by factory)
|
||||
_decrypted_api_key: str = field(default='', repr=False)
|
||||
|
||||
async def get_issue_details(self) -> tuple[str, str]:
|
||||
"""Fetch issue details from Jira API (cached after first call).
|
||||
|
||||
Returns:
|
||||
Tuple of (issue_title, issue_description)
|
||||
|
||||
Raises:
|
||||
StartingConvoException: If issue details cannot be fetched
|
||||
"""
|
||||
if self._issue_title is not None and self._issue_description is not None:
|
||||
return self._issue_title, self._issue_description
|
||||
|
||||
try:
|
||||
url = f'{JIRA_CLOUD_API_URL}/{self.jira_workspace.jira_cloud_id}/rest/api/2/issue/{self.payload.issue_key}'
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
auth=(
|
||||
self.jira_workspace.svc_acc_email,
|
||||
self._decrypted_api_key,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise StartingConvoException(
|
||||
f'Issue {self.payload.issue_key} not found.'
|
||||
)
|
||||
|
||||
self._issue_title = issue_payload.get('fields', {}).get('summary', '')
|
||||
self._issue_description = (
|
||||
issue_payload.get('fields', {}).get('description', '') or ''
|
||||
)
|
||||
|
||||
if not self._issue_title:
|
||||
raise StartingConvoException(
|
||||
f'Issue {self.payload.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Fetched issue details',
|
||||
extra={
|
||||
'issue_key': self.payload.issue_key,
|
||||
'has_description': bool(self._issue_description),
|
||||
},
|
||||
)
|
||||
|
||||
return self._issue_title, self._issue_description
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
'[Jira] Failed to fetch issue details',
|
||||
extra={
|
||||
'issue_key': self.payload.issue_key,
|
||||
'status': e.response.status_code,
|
||||
},
|
||||
)
|
||||
raise StartingConvoException(
|
||||
f'Failed to fetch issue details: HTTP {e.response.status_code}'
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, StartingConvoException):
|
||||
raise
|
||||
logger.error(
|
||||
'[Jira] Failed to fetch issue details',
|
||||
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get instructions for the conversation.
|
||||
|
||||
This fetches issue details if not already cached.
|
||||
|
||||
Returns:
|
||||
Tuple of (system_instructions, user_message)
|
||||
"""
|
||||
issue_title, issue_description = await self.get_issue_details()
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_new_conversation.j2')
|
||||
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.payload.issue_key,
|
||||
issue_title=issue_title,
|
||||
issue_description=issue_description,
|
||||
user_message=self.payload.user_msg,
|
||||
issue_key=self.job_context.issue_key,
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
)
|
||||
|
||||
return instructions, user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Jira conversation.
|
||||
"""Create a new Jira conversation"""
|
||||
|
||||
Returns:
|
||||
The conversation ID
|
||||
|
||||
Raises:
|
||||
StartingConvoException: If conversation creation fails
|
||||
"""
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -181,259 +73,81 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
|
||||
logger.info(
|
||||
'[Jira] Created conversation',
|
||||
extra={
|
||||
'conversation_id': self.conversation_id,
|
||||
'issue_key': self.payload.issue_key,
|
||||
'selected_repo': self.selected_repo,
|
||||
},
|
||||
)
|
||||
logger.info(f'[Jira] Created conversation {self.conversation_id}')
|
||||
|
||||
# Store Jira conversation mapping
|
||||
jira_conversation = JiraConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
issue_id=self.payload.issue_id,
|
||||
issue_key=self.payload.issue_key,
|
||||
issue_id=self.job_context.issue_id,
|
||||
issue_key=self.job_context.issue_key,
|
||||
jira_user_id=self.jira_user.id,
|
||||
)
|
||||
|
||||
await integration_store.create_conversation(jira_conversation)
|
||||
|
||||
return self.conversation_id
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, StartingConvoException):
|
||||
raise
|
||||
logger.error(
|
||||
'[Jira] Failed to create conversation',
|
||||
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira."""
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.payload.display_name} can [track my progress here|{conversation_link}]."
|
||||
return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
class JiraFactory:
|
||||
"""Factory for creating Jira views.
|
||||
|
||||
The factory is responsible for:
|
||||
- Creating the appropriate view type
|
||||
- Inferring and selecting the repository
|
||||
- Validating all required data is available
|
||||
|
||||
Repository selection happens here so that view creation either
|
||||
succeeds with a valid repo or fails with a clear error.
|
||||
"""
|
||||
"""Factory for creating Jira views based on message content"""
|
||||
|
||||
@staticmethod
|
||||
async def _create_provider_handler(user_auth: UserAuth) -> ProviderHandler | None:
|
||||
"""Create a ProviderHandler for the user."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return None
|
||||
def is_labeled_ticket(message: Message) -> bool:
|
||||
payload = message.message.get('payload', {})
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
if event_type != 'jira:issue_updated':
|
||||
return False
|
||||
|
||||
return ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
return OH_LABEL in labels
|
||||
|
||||
@staticmethod
|
||||
def _extract_potential_repos(
|
||||
issue_key: str,
|
||||
issue_title: str,
|
||||
issue_description: str,
|
||||
user_msg: str,
|
||||
) -> list[str]:
|
||||
"""Extract potential repository names from issue content.
|
||||
def is_ticket_comment(message: Message) -> bool:
|
||||
payload = message.message.get('payload', {})
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
Raises:
|
||||
RepositoryNotFoundError: If no potential repos found in text.
|
||||
"""
|
||||
search_text = f'{issue_title}\n{issue_description}\n{user_msg}'
|
||||
potential_repos = infer_repo_from_message(search_text)
|
||||
if event_type != 'comment_created':
|
||||
return False
|
||||
|
||||
if not potential_repos:
|
||||
raise RepositoryNotFoundError(
|
||||
'Could not determine which repository to use. '
|
||||
'Please mention the repository (e.g., owner/repo) in the issue description or comment.'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Found potential repositories in issue content',
|
||||
extra={'issue_key': issue_key, 'potential_repos': potential_repos},
|
||||
)
|
||||
return potential_repos
|
||||
comment_data = payload.get('comment', {})
|
||||
comment_body = comment_data.get('body', '')
|
||||
return has_exact_mention(comment_body, INLINE_OH_LABEL)
|
||||
|
||||
@staticmethod
|
||||
async def _verify_repos(
|
||||
issue_key: str,
|
||||
potential_repos: list[str],
|
||||
provider_handler: ProviderHandler,
|
||||
) -> list[str]:
|
||||
"""Verify which repos the user has access to."""
|
||||
verified_repos: list[str] = []
|
||||
|
||||
for repo_name in potential_repos:
|
||||
try:
|
||||
repository = await provider_handler.verify_repo_provider(repo_name)
|
||||
verified_repos.append(repository.full_name)
|
||||
logger.debug(
|
||||
'[Jira] Repository verification succeeded',
|
||||
extra={'issue_key': issue_key, 'repository': repository.full_name},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
'[Jira] Repository verification failed',
|
||||
extra={
|
||||
'issue_key': issue_key,
|
||||
'repo_name': repo_name,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
return verified_repos
|
||||
|
||||
@staticmethod
|
||||
def _select_single_repo(
|
||||
issue_key: str,
|
||||
potential_repos: list[str],
|
||||
verified_repos: list[str],
|
||||
) -> str:
|
||||
"""Select exactly one repo from verified repos.
|
||||
|
||||
Raises:
|
||||
RepositoryNotFoundError: If zero or multiple repos verified.
|
||||
"""
|
||||
if len(verified_repos) == 0:
|
||||
raise RepositoryNotFoundError(
|
||||
f'Could not access any of the mentioned repositories: {", ".join(potential_repos)}. '
|
||||
'Please ensure you have access to the repository and it exists.'
|
||||
)
|
||||
|
||||
if len(verified_repos) > 1:
|
||||
raise RepositoryNotFoundError(
|
||||
f'Multiple repositories found: {", ".join(verified_repos)}. '
|
||||
'Please specify exactly one repository in the issue description or comment.'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Verified repository access',
|
||||
extra={'issue_key': issue_key, 'repository': verified_repos[0]},
|
||||
)
|
||||
return verified_repos[0]
|
||||
|
||||
@staticmethod
|
||||
async def _infer_repository(
|
||||
payload: JiraWebhookPayload,
|
||||
user_auth: UserAuth,
|
||||
issue_title: str,
|
||||
issue_description: str,
|
||||
) -> str:
|
||||
"""Infer and verify the repository from issue content.
|
||||
|
||||
Raises:
|
||||
RepositoryNotFoundError: If no valid repository can be determined.
|
||||
"""
|
||||
provider_handler = await JiraFactory._create_provider_handler(user_auth)
|
||||
if not provider_handler:
|
||||
raise RepositoryNotFoundError(
|
||||
'No Git provider connected. Please connect a Git provider in OpenHands settings.'
|
||||
)
|
||||
|
||||
potential_repos = JiraFactory._extract_potential_repos(
|
||||
payload.issue_key, issue_title, issue_description, payload.user_msg
|
||||
)
|
||||
|
||||
verified_repos = await JiraFactory._verify_repos(
|
||||
payload.issue_key, potential_repos, provider_handler
|
||||
)
|
||||
|
||||
return JiraFactory._select_single_repo(
|
||||
payload.issue_key, potential_repos, verified_repos
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_view(
|
||||
payload: JiraWebhookPayload,
|
||||
workspace: JiraWorkspace,
|
||||
user: JiraUser,
|
||||
user_auth: UserAuth,
|
||||
decrypted_api_key: str,
|
||||
async def create_jira_view_from_payload(
|
||||
job_context: JobContext,
|
||||
saas_user_auth: UserAuth,
|
||||
jira_user: JiraUser,
|
||||
jira_workspace: JiraWorkspace,
|
||||
) -> JiraViewInterface:
|
||||
"""Create a Jira view with repository already selected.
|
||||
"""Create appropriate Jira view based on the message and user state"""
|
||||
|
||||
This factory method:
|
||||
1. Creates the view with payload and auth context
|
||||
2. Fetches issue details (needed for repo inference)
|
||||
3. Infers and selects the repository
|
||||
if not jira_user or not saas_user_auth or not jira_workspace:
|
||||
raise StartingConvoException('User not authenticated with Jira integration')
|
||||
|
||||
If any step fails, an appropriate exception is raised with
|
||||
a user-friendly message.
|
||||
|
||||
Args:
|
||||
payload: Parsed webhook payload
|
||||
workspace: The Jira workspace
|
||||
user: The Jira user
|
||||
user_auth: OpenHands user authentication
|
||||
decrypted_api_key: Decrypted service account API key
|
||||
|
||||
Returns:
|
||||
A JiraViewInterface with selected_repo populated
|
||||
|
||||
Raises:
|
||||
StartingConvoException: If view creation fails
|
||||
RepositoryNotFoundError: If repository cannot be determined
|
||||
"""
|
||||
logger.info(
|
||||
'[Jira] Creating view',
|
||||
extra={
|
||||
'issue_key': payload.issue_key,
|
||||
'event_type': payload.event_type.value,
|
||||
},
|
||||
return JiraNewConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None, # Will be set later after repo inference
|
||||
conversation_id='', # Will be set when conversation is created
|
||||
)
|
||||
|
||||
# Create the view
|
||||
view = JiraNewConversationView(
|
||||
payload=payload,
|
||||
saas_user_auth=user_auth,
|
||||
jira_user=user,
|
||||
jira_workspace=workspace,
|
||||
_decrypted_api_key=decrypted_api_key,
|
||||
)
|
||||
|
||||
# Fetch issue details (needed for repo inference)
|
||||
try:
|
||||
issue_title, issue_description = await view.get_issue_details()
|
||||
except StartingConvoException:
|
||||
raise # Re-raise with original message
|
||||
except Exception as e:
|
||||
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
|
||||
|
||||
# Infer and select repository
|
||||
selected_repo = await JiraFactory._infer_repository(
|
||||
payload=payload,
|
||||
user_auth=user_auth,
|
||||
issue_title=issue_title,
|
||||
issue_description=issue_description,
|
||||
)
|
||||
|
||||
view.selected_repo = selected_repo
|
||||
|
||||
logger.info(
|
||||
'[Jira] View created successfully',
|
||||
extra={
|
||||
'issue_key': payload.issue_key,
|
||||
'selected_repo': selected_repo,
|
||||
},
|
||||
)
|
||||
|
||||
return view
|
||||
|
||||
@@ -16,6 +16,11 @@ class Manager(ABC):
|
||||
"Send message to integration from Openhands server"
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
"Confirm that a job is being requested"
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self):
|
||||
"Kick off a job with openhands agent"
|
||||
|
||||
@@ -398,42 +398,53 @@ def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
"""
|
||||
Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
and direct mentions in text. Supports GitHub, GitLab, and BitBucket.
|
||||
Args:
|
||||
user_msg: Input message that may contain repository references
|
||||
Returns:
|
||||
List of repository names in 'owner/repo' format, empty list if none found
|
||||
"""
|
||||
# Normalize the message by removing extra whitespace and newlines
|
||||
normalized_msg = re.sub(r'\s+', ' ', user_msg.strip())
|
||||
|
||||
git_url_pattern = (
|
||||
r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/'
|
||||
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?'
|
||||
r'(?:[/?#].*?)?(?=\s|$|[^\w.-])'
|
||||
)
|
||||
# Pattern to match Git URLs from GitHub, GitLab, and BitBucket
|
||||
# Captures: protocol, domain, owner, repo (with optional .git extension)
|
||||
git_url_pattern = r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?(?:[/?#].*?)?(?=\s|$|[^\w.-])'
|
||||
|
||||
# UPDATED: allow {{ owner/repo }} in addition to existing boundaries
|
||||
# Pattern to match direct owner/repo mentions (e.g., "OpenHands/OpenHands")
|
||||
# Must be surrounded by word boundaries or specific characters to avoid false positives
|
||||
direct_pattern = (
|
||||
r'(?:^|\s|{{|[\[\(\'":`])' # left boundary
|
||||
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)'
|
||||
r'(?=\s|$|}}|[\]\)\'",.:`])' # right boundary
|
||||
r'(?:^|\s|[\[\(\'"])([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)(?=\s|$|[\]\)\'",.])'
|
||||
)
|
||||
|
||||
matches: list[str] = []
|
||||
matches = []
|
||||
|
||||
# Git URLs first (highest priority)
|
||||
for owner, repo in re.findall(git_url_pattern, normalized_msg):
|
||||
# First, find all Git URLs (highest priority)
|
||||
git_matches = re.findall(git_url_pattern, normalized_msg)
|
||||
for owner, repo in git_matches:
|
||||
# Remove .git extension if present
|
||||
repo = re.sub(r'\.git$', '', repo)
|
||||
matches.append(f'{owner}/{repo}')
|
||||
|
||||
# Direct mentions
|
||||
for owner, repo in re.findall(direct_pattern, normalized_msg):
|
||||
# Second, find all direct owner/repo mentions
|
||||
direct_matches = re.findall(direct_pattern, normalized_msg)
|
||||
for owner, repo in direct_matches:
|
||||
full_match = f'{owner}/{repo}'
|
||||
|
||||
# Skip if it looks like a version number, date, or file path
|
||||
if (
|
||||
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match)
|
||||
or re.match(r'^\d{1,2}/\d{1,2}$', full_match)
|
||||
or re.match(r'^[A-Z]/[A-Z]$', full_match)
|
||||
or repo.endswith(('.txt', '.md', '.py', '.js'))
|
||||
or ('.' in repo and len(repo.split('.')) > 2)
|
||||
):
|
||||
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match) # version numbers
|
||||
or re.match(r'^\d{1,2}/\d{1,2}$', full_match) # dates
|
||||
or re.match(r'^[A-Z]/[A-Z]$', full_match) # single letters
|
||||
or repo.endswith('.txt')
|
||||
or repo.endswith('.md') # file extensions
|
||||
or repo.endswith('.py')
|
||||
or repo.endswith('.js')
|
||||
or '.' in repo
|
||||
and len(repo.split('.')) > 2
|
||||
): # complex file paths
|
||||
continue
|
||||
|
||||
# Avoid duplicates from Git URLs already found
|
||||
if full_match not in matches:
|
||||
matches.append(full_match)
|
||||
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Add git_user_name and git_user_email columns to user table.
|
||||
|
||||
Revision ID: 090
|
||||
Revises: 089
|
||||
Create Date: 2025-01-22
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '090'
|
||||
down_revision = '089'
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'user',
|
||||
sa.Column('git_user_name', sa.String, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
'user',
|
||||
sa.Column('git_user_email', sa.String, nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('user', 'git_user_email')
|
||||
op.drop_column('user', 'git_user_name')
|
||||
Generated
+12
-12
@@ -6102,14 +6102,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.10.0"
|
||||
version = "1.8.2"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
|
||||
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
|
||||
{file = "openhands_agent_server-1.8.2-py3-none-any.whl", hash = "sha256:e9abb2e0fe970715537d0e0fc1aea3dd64bb9e8b531f70cb72b3d4e486aaa46a"},
|
||||
{file = "openhands_agent_server-1.8.2.tar.gz", hash = "sha256:43db2371ee84b100ac921396338dee74359fceeb5c9400c90530bcc5730144c3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6168,9 +6168,9 @@ memory-profiler = ">=0.61"
|
||||
numpy = "*"
|
||||
openai = "2.8"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = "1.10"
|
||||
openhands-sdk = "1.10"
|
||||
openhands-tools = "1.10"
|
||||
openhands-agent-server = "1.8.2"
|
||||
openhands-sdk = "1.8.2"
|
||||
openhands-tools = "1.8.2"
|
||||
opentelemetry-api = ">=1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
pathspec = ">=0.12.1"
|
||||
@@ -6225,14 +6225,14 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.10.0"
|
||||
version = "1.8.2"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
|
||||
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
|
||||
{file = "openhands_sdk-1.8.2-py3-none-any.whl", hash = "sha256:b4fad9581865ce222a3e6722384e4df56113db01bd34c2d2d408dfd9695365c0"},
|
||||
{file = "openhands_sdk-1.8.2.tar.gz", hash = "sha256:5bfb17c8b9515210d121249deb1f3d0dc407c3737edc55b5e73330b4571d61e3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6253,14 +6253,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.10.0"
|
||||
version = "1.8.2"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
|
||||
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
|
||||
{file = "openhands_tools-1.8.2-py3-none-any.whl", hash = "sha256:283f0c1fdd316914559cd16ade792383715478a8f5a73f7166daffc34bf9e5af"},
|
||||
{file = "openhands_tools-1.8.2.tar.gz", hash = "sha256:eae416e3867f7cb595129a33a4b9237886c4b8a075d2bc7618da55963f2747d5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -16,7 +16,6 @@ from keycloak.exceptions import (
|
||||
KeycloakError,
|
||||
KeycloakPostError,
|
||||
)
|
||||
from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import (
|
||||
BITBUCKET_APP_CLIENT_ID,
|
||||
BITBUCKET_APP_CLIENT_SECRET,
|
||||
@@ -427,8 +426,6 @@ class TokenManager:
|
||||
access_token = data.get('access_token')
|
||||
refresh_token = data.get('refresh_token')
|
||||
if not access_token or not refresh_token:
|
||||
if data.get('error') == 'bad_refresh_token':
|
||||
raise ExpiredError()
|
||||
raise ValueError(
|
||||
'Failed to refresh token: missing access_token or refresh_token in response.'
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -107,10 +108,13 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
f'[GitHub] Sent summary instruction to conversation {conversation_id} {summary_event}'
|
||||
)
|
||||
|
||||
# Update the processor state - the outer session will commit this
|
||||
# Update the processor state
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -126,15 +130,14 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
|
||||
logger.info(f'[GitHub] Summary sent for conversation {conversation_id}')
|
||||
|
||||
# Mark callback as completed status - the outer session will commit this
|
||||
# Mark callback as completed status
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f'[GitHub] Error processing conversation callback: {str(e)}'
|
||||
)
|
||||
# Mark callback as error to prevent infinite re-invocation
|
||||
# The outer session will commit this
|
||||
callback.status = CallbackStatus.ERROR
|
||||
callback.updated_at = datetime.now()
|
||||
|
||||
@@ -144,7 +144,7 @@ class SetAuthCookieMiddleware:
|
||||
# "if accepted_tos is not None" as there should not be any users with
|
||||
# accepted_tos equal to "None"
|
||||
if accepted_tos is False and request.url.path != '/api/accept_tos':
|
||||
logger.warning('User has not accepted the terms of service')
|
||||
logger.error('User has not accepted the terms of service')
|
||||
raise TosNotAcceptedError
|
||||
|
||||
def _should_attach(self, request: Request) -> bool:
|
||||
|
||||
@@ -13,33 +13,46 @@ from server.constants import (
|
||||
STRIPE_API_KEY,
|
||||
)
|
||||
from server.logger import logger
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.app_server.config import get_global_config
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
|
||||
|
||||
async def validate_billing_enabled() -> None:
|
||||
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
|
||||
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
|
||||
def is_all_hands_saas_environment(request: Request) -> bool:
|
||||
"""Check if the current domain is an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
|
||||
"""
|
||||
Validate that the billing feature flag is enabled
|
||||
hostname = request.url.hostname or ''
|
||||
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
|
||||
|
||||
|
||||
def validate_saas_environment(request: Request) -> None:
|
||||
"""Validate that the request is coming from an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Raises:
|
||||
HTTPException: If the request is not from an All Hands SaaS environment
|
||||
"""
|
||||
config = get_global_config()
|
||||
web_client_config = await config.web_client.get_web_client_config()
|
||||
if not web_client_config.feature_flags.enable_billing:
|
||||
if not is_all_hands_saas_environment(request):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=(
|
||||
'Billing is disabled in this environment. '
|
||||
'Please set OH_WEB_CLIENT_FEATURE_FLAGS_ENABLE_BILLING to enable billing.'
|
||||
),
|
||||
detail='Checkout sessions are only available for All Hands SaaS environments',
|
||||
)
|
||||
|
||||
|
||||
@@ -141,15 +154,14 @@ async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
await validate_billing_enabled()
|
||||
validate_saas_environment(request)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
base_url = _get_base_url(request)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{base_url}?free_credits=success',
|
||||
cancel_url=f'{base_url}',
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
@@ -161,8 +173,8 @@ async def create_checkout_session(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
await validate_billing_enabled()
|
||||
base_url = _get_base_url(request)
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_info['customer_id'],
|
||||
@@ -185,8 +197,8 @@ async def create_checkout_session(
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
@@ -277,7 +289,7 @@ async def success_callback(session_id: str, request: Request):
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
|
||||
f'{request.base_url}settings/billing?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@@ -305,13 +317,5 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
|
||||
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
def _get_base_url(request: Request) -> URL:
|
||||
# Never send any part of the credit card process over a non secure connection
|
||||
base_url = request.base_url
|
||||
if base_url.hostname != 'localhost':
|
||||
base_url = base_url.replace(scheme='https')
|
||||
return base_url
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zlib
|
||||
from base64 import b64decode, b64encode
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
@@ -52,11 +51,7 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
state_payload = json.dumps(
|
||||
[query_params['state'][0], query_params['redirect_uri'][0]]
|
||||
)
|
||||
# Compress before encrypting to reduce URL length
|
||||
# This is critical for feature deployments where reCAPTCHA tokens in state
|
||||
# can cause "URL too long" errors from GitHub
|
||||
compressed_payload = zlib.compress(state_payload.encode())
|
||||
state = b64encode(_fernet().encrypt(compressed_payload)).decode()
|
||||
state = b64encode(_fernet().encrypt(state_payload.encode())).decode()
|
||||
query_params['state'] = [state]
|
||||
query_params['redirect_uri'] = [
|
||||
f'https://{request.url.netloc}/github-proxy/callback'
|
||||
@@ -72,9 +67,7 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
parsed_url = urlparse(str(request.url))
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
state = query_params['state'][0]
|
||||
# Decrypt and decompress (reverse of github_proxy_start)
|
||||
decrypted_payload = _fernet().decrypt(b64decode(state.encode()))
|
||||
decrypted_state = zlib.decompress(decrypted_payload).decode()
|
||||
decrypted_state = _fernet().decrypt(b64decode(state.encode())).decode()
|
||||
|
||||
# Build query Params
|
||||
state, redirect_uri = json.loads(decrypted_state)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
class OrgCreationError(Exception):
|
||||
@@ -28,27 +27,6 @@ class OrgDatabaseError(OrgCreationError):
|
||||
pass
|
||||
|
||||
|
||||
class OrgDeletionError(Exception):
|
||||
"""Base exception for organization deletion errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OrgAuthorizationError(OrgDeletionError):
|
||||
"""Raised when user is not authorized to delete organization."""
|
||||
|
||||
def __init__(self, message: str = 'Not authorized to delete organization'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OrgNotFoundError(Exception):
|
||||
"""Raised when organization is not found or user doesn't have access."""
|
||||
|
||||
def __init__(self, org_id: str):
|
||||
self.org_id = org_id
|
||||
super().__init__(f'Organization with id "{org_id}" not found')
|
||||
|
||||
|
||||
class OrgCreate(BaseModel):
|
||||
"""Request model for creating a new organization."""
|
||||
|
||||
@@ -87,85 +65,3 @@ class OrgResponse(BaseModel):
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
credits: float | None = None
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
|
||||
"""Create an OrgResponse from an Org entity.
|
||||
|
||||
Args:
|
||||
org: The organization entity to convert
|
||||
credits: Optional credits value (defaults to None)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The response model instance
|
||||
"""
|
||||
return cls(
|
||||
id=str(org.id),
|
||||
name=org.name,
|
||||
contact_name=org.contact_name,
|
||||
contact_email=org.contact_email,
|
||||
conversation_expiration=org.conversation_expiration,
|
||||
agent=org.agent,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_api_key_for_byor=None,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser
|
||||
if org.enable_default_condenser is not None
|
||||
else True,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
|
||||
if org.enable_proactive_conversation_starters is not None
|
||||
else True,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
org_version=org.org_version if org.org_version is not None else 0,
|
||||
mcp_config=org.mcp_config,
|
||||
search_api_key=None,
|
||||
sandbox_api_key=None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
)
|
||||
|
||||
|
||||
class OrgPage(BaseModel):
|
||||
"""Paginated response model for organization list."""
|
||||
|
||||
items: list[OrgResponse]
|
||||
next_page_id: str | None = None
|
||||
|
||||
|
||||
class OrgUpdate(BaseModel):
|
||||
"""Request model for updating an organization."""
|
||||
|
||||
# Basic organization information (any authenticated user can update)
|
||||
contact_name: str | None = None
|
||||
contact_email: EmailStr | None = Field(default=None, strip_whitespace=True)
|
||||
conversation_expiration: int | None = None
|
||||
default_max_iterations: int | None = Field(default=None, gt=0)
|
||||
remote_runtime_resource_factor: int | None = Field(default=None, gt=0)
|
||||
billing_margin: float | None = Field(default=None, ge=0, le=1)
|
||||
enable_proactive_conversation_starters: bool | None = None
|
||||
sandbox_base_container_image: str | None = None
|
||||
sandbox_runtime_container_image: str | None = None
|
||||
mcp_config: dict | None = None
|
||||
sandbox_api_key: str | None = None
|
||||
max_budget_per_task: float | None = Field(default=None, gt=0)
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
|
||||
# LLM settings (require admin/owner role)
|
||||
default_llm_model: str | None = None
|
||||
default_llm_api_key_for_byor: str | None = None
|
||||
default_llm_base_url: str | None = None
|
||||
search_api_key: str | None = None
|
||||
security_analyzer: str | None = None
|
||||
agent: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
enable_default_condenser: bool | None = None
|
||||
condenser_max_size: int | None = Field(default=None, ge=20)
|
||||
|
||||
@@ -1,98 +1,20 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgAuthorizationError,
|
||||
OrgCreate,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrgPage,
|
||||
OrgResponse,
|
||||
OrgUpdate,
|
||||
)
|
||||
from storage.org_service import OrgService
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# Initialize API router
|
||||
org_router = APIRouter(prefix='/api/organizations')
|
||||
|
||||
|
||||
@org_router.get('', response_model=OrgPage)
|
||||
async def list_user_orgs(
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
] = 100,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgPage:
|
||||
"""List organizations for the authenticated user.
|
||||
|
||||
This endpoint returns a paginated list of all organizations that the
|
||||
authenticated user is a member of.
|
||||
|
||||
Args:
|
||||
page_id: Optional page ID (offset) for pagination
|
||||
limit: Maximum number of organizations to return (1-100, default 100)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgPage: Paginated list of organizations
|
||||
|
||||
Raises:
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Listing organizations for user',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'page_id': page_id,
|
||||
'limit': limit,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch organizations from service layer
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Convert Org entities to OrgResponse objects
|
||||
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organizations',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_count': len(org_responses),
|
||||
'has_more': next_page_id is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return OrgPage(items=org_responses, next_page_id=next_page_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error listing organizations',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve organizations',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post('', response_model=OrgResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_org(
|
||||
org_data: OrgCreate,
|
||||
@@ -136,7 +58,31 @@ async def create_org(
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
return OrgResponse(
|
||||
id=str(org.id),
|
||||
name=org.name,
|
||||
contact_name=org.contact_name,
|
||||
contact_email=org.contact_email,
|
||||
conversation_expiration=org.conversation_expiration,
|
||||
agent=org.agent,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
org_version=org.org_version,
|
||||
mcp_config=org.mcp_config,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -169,234 +115,3 @@ async def create_org(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
|
||||
async def get_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Get organization details by ID.
|
||||
|
||||
This endpoint allows authenticated users who are members of an organization
|
||||
to retrieve its details. Only members of the organization can access this endpoint.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 404 if organization not found or user is not a member
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Retrieving organization details',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to get organization with membership validation
|
||||
org = await OrgService.get_org_by_id(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error retrieving organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
|
||||
async def delete_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
) -> dict:
|
||||
"""Delete an organization.
|
||||
|
||||
This endpoint allows authenticated organization owners to delete their organization.
|
||||
All associated data including organization members, conversations, billing data,
|
||||
and external LiteLLM team resources will be permanently removed.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to delete
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with deleted organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is not the organization owner
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if deletion fails
|
||||
"""
|
||||
logger.info(
|
||||
'Organization deletion requested',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to delete organization with cleanup
|
||||
deleted_org = await OrgService.delete_org_with_cleanup(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Organization deletion completed successfully',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': deleted_org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
'message': 'Organization deleted successfully',
|
||||
'organization': {
|
||||
'id': str(deleted_org.id),
|
||||
'name': deleted_org.name,
|
||||
'contact_name': deleted_org.contact_name,
|
||||
'contact_email': deleted_org.contact_email,
|
||||
},
|
||||
}
|
||||
|
||||
except OrgNotFoundError as e:
|
||||
logger.warning(
|
||||
'Organization not found for deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgAuthorizationError as e:
|
||||
logger.warning(
|
||||
'User not authorized to delete organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database error during organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to delete organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error during organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.patch('/{org_id}', response_model=OrgResponse)
|
||||
async def update_org(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Update an existing organization.
|
||||
|
||||
This endpoint allows authenticated users to update organization settings.
|
||||
LLM-related settings require admin or owner role in the organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to update (UUID validated by FastAPI)
|
||||
update_data: Organization update data
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The updated organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
|
||||
HTTPException: 403 if user lacks permission for LLM settings
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 422 if validation errors occur (handled by FastAPI)
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Updating organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to update organization with permission checks
|
||||
updated_org = await OrgService.update_org_with_permissions(
|
||||
org_id=org_id,
|
||||
update_data=update_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
|
||||
credits = await OrgService.get_org_credits(user_id, updated_org.id)
|
||||
|
||||
return OrgResponse.from_org(updated_org, credits=credits)
|
||||
|
||||
except ValueError as e:
|
||||
# Organization not found
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except PermissionError as e:
|
||||
# User lacks permission for LLM settings
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database operation failed',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error updating organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ from server.sharing.shared_conversation_models import (
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
@@ -58,7 +57,7 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
include_sub_conversations: bool = False,
|
||||
) -> SharedConversationPage:
|
||||
"""Search for shared conversations."""
|
||||
query = self._public_select_with_saas_metadata()
|
||||
query = self._public_select()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
@@ -105,17 +104,14 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
rows = result.scalars().all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [
|
||||
self._to_shared_conversation(stored, saas_metadata=saas_metadata)
|
||||
for stored, saas_metadata in rows
|
||||
]
|
||||
items = [self._to_shared_conversation(row) for row in rows]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
@@ -156,18 +152,17 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single public conversation info, returning None if missing or not shared."""
|
||||
query = self._public_select_with_saas_metadata().where(
|
||||
query = self._public_select().where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
row = result.first()
|
||||
stored = result.scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
stored, saas_metadata = row
|
||||
return self._to_shared_conversation(stored, saas_metadata=saas_metadata)
|
||||
return self._to_shared_conversation(stored)
|
||||
|
||||
def _public_select(self):
|
||||
"""Create a select query that only returns public conversations."""
|
||||
@@ -178,25 +173,6 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
return query
|
||||
|
||||
def _public_select_with_saas_metadata(self):
|
||||
"""Create a select query that returns public conversations with SAAS metadata.
|
||||
|
||||
This joins with conversation_metadata_saas to retrieve the user_id needed
|
||||
for constructing the correct event storage path. Uses LEFT OUTER JOIN to
|
||||
support conversations that may not have SAAS metadata (e.g., in tests).
|
||||
"""
|
||||
query = (
|
||||
select(StoredConversationMetadata, StoredConversationMetadataSaas)
|
||||
.outerjoin(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
)
|
||||
return query
|
||||
|
||||
def _apply_filters(
|
||||
self,
|
||||
query,
|
||||
@@ -235,16 +211,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
def _to_shared_conversation(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas | None = None,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> SharedConversation:
|
||||
"""Convert StoredConversationMetadata to SharedConversation.
|
||||
|
||||
Args:
|
||||
stored: The base conversation metadata from conversation_metadata table.
|
||||
saas_metadata: Optional SAAS metadata containing user_id and org_id.
|
||||
sub_conversation_ids: Optional list of sub-conversation IDs.
|
||||
"""
|
||||
"""Convert StoredConversationMetadata to SharedConversation."""
|
||||
# V1 conversations should always have a sandbox_id
|
||||
sandbox_id = stored.sandbox_id
|
||||
assert sandbox_id is not None
|
||||
@@ -270,16 +239,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
created_at = self._fix_timezone(stored.created_at)
|
||||
updated_at = self._fix_timezone(stored.last_updated_at)
|
||||
|
||||
# Get user_id from SAAS metadata if available
|
||||
created_by_user_id = (
|
||||
str(saas_metadata.user_id)
|
||||
if saas_metadata and saas_metadata.user_id
|
||||
else None
|
||||
)
|
||||
|
||||
return SharedConversation(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=created_by_user_id,
|
||||
created_by_user_id=None, # user_id is no longer stored in conversation metadata
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
|
||||
@@ -98,29 +98,6 @@ def decrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
return get_fernet().decrypt(b64decode(value.encode())).decode()
|
||||
|
||||
|
||||
def encrypt_legacy_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_legacy_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = encrypt_legacy_value(value)
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return b64encode(
|
||||
get_fernet().encrypt(value.get_secret_value().encode())
|
||||
).decode()
|
||||
else:
|
||||
return b64encode(get_fernet().encrypt(value.encode())).decode()
|
||||
|
||||
|
||||
def get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
|
||||
@@ -96,7 +96,7 @@ class LiteLlmManager:
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:start',
|
||||
'SettingsStore:umigrate_lite_llm_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
@@ -141,35 +141,19 @@ class LiteLlmManager:
|
||||
return None
|
||||
credits = max(max_budget - spend, 0.0)
|
||||
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:create_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_user',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_user(
|
||||
client, keycloak_user_id, max_budget=UNLIMITED_BUDGET_SETTING
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:add_user_to_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key:
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
@@ -178,10 +162,6 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_byor_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
@@ -189,167 +169,6 @@ class LiteLlmManager:
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:complete',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
async def downgrade_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
"""Downgrade a migrated user's LiteLLM entries back to the pre-migration state.
|
||||
|
||||
This reverses the migrate_entries operation:
|
||||
1. Get the user max budget from their org team in litellm
|
||||
2. Set the max budget in the user in litellm (restore from team)
|
||||
3. Add the user back to the default team in litellm
|
||||
4. Update keys to remove org team association
|
||||
5. Remove the user from their org team in litellm
|
||||
6. Delete the user org team in litellm
|
||||
|
||||
Note: The database changes (already_migrated flag, org/org_member deletion)
|
||||
should be handled separately by the caller.
|
||||
|
||||
Args:
|
||||
org_id: The organization ID (which is also the team_id in litellm)
|
||||
keycloak_user_id: The user's Keycloak ID
|
||||
user_settings: The user's settings object
|
||||
|
||||
Returns:
|
||||
The user_settings if downgrade was successful, None otherwise
|
||||
"""
|
||||
logger.info(
|
||||
'LiteLlmManager:downgrade_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
if not local_deploy:
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
# Step 1: Get the team info to retrieve the budget
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:get_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
team_info = await LiteLlmManager._get_team(client, org_id)
|
||||
if not team_info:
|
||||
logger.error(
|
||||
'LiteLlmManager:downgrade_entries:team_not_found',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get team budget (max_budget) and spend to calculate current credits
|
||||
team_data = team_info.get('team_info', {})
|
||||
max_budget = team_data.get('max_budget', 0.0)
|
||||
spend = team_data.get('spend', 0.0)
|
||||
|
||||
# Get user membership info for budget in team
|
||||
user_membership = await LiteLlmManager._get_user_team_info(
|
||||
client, keycloak_user_id, org_id
|
||||
)
|
||||
if user_membership:
|
||||
# Use user's budget in team if available
|
||||
user_max_budget_in_team = user_membership.get('max_budget_in_team')
|
||||
user_spend_in_team = user_membership.get('spend', 0.0)
|
||||
if user_max_budget_in_team is not None:
|
||||
max_budget = user_max_budget_in_team
|
||||
spend = user_spend_in_team
|
||||
|
||||
# Calculate total budget to restore (credits + spend = max_budget)
|
||||
# We restore the full max_budget that was on the team/user-in-team
|
||||
restored_budget = max_budget if max_budget else 0.0
|
||||
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:budget_info',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
'restored_budget': restored_budget,
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: Update user to set their max_budget back from unlimited
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_user',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_user(
|
||||
client, keycloak_user_id, max_budget=restored_budget, spend=spend
|
||||
)
|
||||
|
||||
# Step 3: Add user back to the default team
|
||||
if LITE_LLM_TEAM_ID:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:add_to_default_team',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'default_team_id': LITE_LLM_TEAM_ID,
|
||||
},
|
||||
)
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, LITE_LLM_TEAM_ID, restored_budget
|
||||
)
|
||||
|
||||
# Step 4: Update keys to remove org team association (set team_id to default)
|
||||
if user_settings.llm_api_key:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key,
|
||||
team_id=LITE_LLM_TEAM_ID,
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_byor_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key_for_byor,
|
||||
team_id=LITE_LLM_TEAM_ID,
|
||||
)
|
||||
|
||||
# Step 5: Remove user from their org team
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:remove_from_org_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._remove_user_from_team(
|
||||
client, keycloak_user_id, org_id
|
||||
)
|
||||
|
||||
# Step 6: Delete the org team
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:delete_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._delete_team(client, org_id)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:downgrade_entries:complete',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
@@ -794,45 +613,6 @@ class LiteLlmManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _remove_user_from_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_delete',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
if not response.is_success:
|
||||
if response.status_code == 404:
|
||||
# User not in team, that's fine for downgrade
|
||||
logger.info(
|
||||
'User not in team during removal',
|
||||
extra={'user_id': keycloak_user_id, 'team_id': team_id},
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_removing_litellm_user_from_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
'LiteLlmManager:_remove_user_from_team:user_removed',
|
||||
extra={'user_id': keycloak_user_id, 'team_id': team_id},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _generate_key(
|
||||
client: httpx.AsyncClient,
|
||||
@@ -1076,7 +856,6 @@ class LiteLlmManager:
|
||||
delete_user = staticmethod(with_http_client(_delete_user))
|
||||
delete_team = staticmethod(with_http_client(_delete_team))
|
||||
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
|
||||
remove_user_from_team = staticmethod(with_http_client(_remove_user_from_team))
|
||||
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
|
||||
@@ -9,11 +9,8 @@ from uuid import UUID as parse_uuid
|
||||
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgAuthorizationError,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrgUpdate,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
@@ -396,224 +393,6 @@ class OrgService:
|
||||
)
|
||||
return e
|
||||
|
||||
@staticmethod
|
||||
def has_admin_or_owner_role(user_id: str, org_id: UUID) -> bool:
|
||||
"""
|
||||
Check if user has admin or owner role in the specified organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
org_id: Organization ID to check membership in
|
||||
|
||||
Returns:
|
||||
bool: True if user has admin or owner role, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Parse user_id as UUID for database query
|
||||
user_uuid = parse_uuid(user_id)
|
||||
|
||||
# Get the user's membership in this organization
|
||||
# Note: The type annotation says int but the actual column is UUID
|
||||
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
|
||||
if not org_member:
|
||||
return False
|
||||
|
||||
# Get the role details
|
||||
role = RoleStore.get_role_by_id(org_member.role_id)
|
||||
if not role:
|
||||
return False
|
||||
|
||||
# Admin and owner roles have elevated permissions
|
||||
# Based on test files, both admin and owner have rank 1
|
||||
return role.name in ['admin', 'owner']
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'Error checking user role in organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_org_member(user_id: str, org_id: UUID) -> bool:
|
||||
"""
|
||||
Check if user is a member of the specified organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
org_id: Organization ID to check membership in
|
||||
|
||||
Returns:
|
||||
bool: True if user is a member, False otherwise
|
||||
"""
|
||||
try:
|
||||
user_uuid = parse_uuid(user_id)
|
||||
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
|
||||
return org_member is not None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'Error checking user membership in organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_llm_settings_fields() -> set[str]:
|
||||
"""
|
||||
Get the set of organization fields that are considered LLM settings
|
||||
and require admin/owner role to update.
|
||||
|
||||
Returns:
|
||||
set[str]: Set of field names that require elevated permissions
|
||||
"""
|
||||
return {
|
||||
'default_llm_model',
|
||||
'default_llm_api_key_for_byor',
|
||||
'default_llm_base_url',
|
||||
'search_api_key',
|
||||
'security_analyzer',
|
||||
'agent',
|
||||
'confirmation_mode',
|
||||
'enable_default_condenser',
|
||||
'condenser_max_size',
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_llm_settings_updates(update_data: OrgUpdate) -> set[str]:
|
||||
"""
|
||||
Check if the update contains any LLM settings fields.
|
||||
|
||||
Args:
|
||||
update_data: The organization update data
|
||||
|
||||
Returns:
|
||||
set[str]: Set of LLM fields being updated (empty if none)
|
||||
"""
|
||||
llm_fields = OrgService._get_llm_settings_fields()
|
||||
update_dict = update_data.model_dump(exclude_none=True)
|
||||
return llm_fields.intersection(update_dict.keys())
|
||||
|
||||
@staticmethod
|
||||
async def update_org_with_permissions(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str,
|
||||
) -> Org:
|
||||
"""
|
||||
Update organization with permission checks for LLM settings.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID to update
|
||||
update_data: Organization update data from request
|
||||
user_id: ID of the user requesting the update
|
||||
|
||||
Returns:
|
||||
Org: The updated organization object
|
||||
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Updating organization with permission checks',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'has_update_data': update_data is not None,
|
||||
},
|
||||
)
|
||||
|
||||
# Validate organization exists
|
||||
existing_org = OrgStore.get_org_by_id(org_id)
|
||||
if not existing_org:
|
||||
raise ValueError(f'Organization with ID {org_id} not found')
|
||||
|
||||
# Check if user is a member of this organization
|
||||
if not OrgService.is_org_member(user_id, org_id):
|
||||
logger.warning(
|
||||
'Non-member attempted to update organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise PermissionError(
|
||||
'User must be a member of the organization to update it'
|
||||
)
|
||||
|
||||
# Check if update contains any LLM settings
|
||||
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
|
||||
if llm_fields_being_updated:
|
||||
# Verify user has admin or owner role
|
||||
has_permission = OrgService.has_admin_or_owner_role(user_id, org_id)
|
||||
if not has_permission:
|
||||
logger.warning(
|
||||
'User attempted to update LLM settings without permission',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'attempted_fields': list(llm_fields_being_updated),
|
||||
},
|
||||
)
|
||||
raise PermissionError(
|
||||
'Admin or owner role required to update LLM settings'
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'User has permission to update LLM settings',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'llm_fields': list(llm_fields_being_updated),
|
||||
},
|
||||
)
|
||||
|
||||
# Convert to dict for OrgStore (excluding None values)
|
||||
update_dict = update_data.model_dump(exclude_none=True)
|
||||
if not update_dict:
|
||||
logger.info(
|
||||
'No fields to update',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id},
|
||||
)
|
||||
return existing_org
|
||||
|
||||
# Perform the update
|
||||
try:
|
||||
updated_org = OrgStore.update_org(org_id, update_dict)
|
||||
if not updated_org:
|
||||
raise OrgDatabaseError('Failed to update organization in database')
|
||||
|
||||
logger.info(
|
||||
'Organization updated successfully',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'updated_fields': list(update_dict.keys()),
|
||||
},
|
||||
)
|
||||
|
||||
return updated_org
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to update organization',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to update organization: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
async def get_org_credits(user_id: str, org_id: UUID) -> float | None:
|
||||
"""
|
||||
@@ -662,183 +441,3 @@ class OrgService:
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
user_id: str, page_id: str | None = None, limit: int = 100
|
||||
):
|
||||
"""
|
||||
Get paginated list of organizations for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
page_id: Optional page ID (offset as string) for pagination
|
||||
limit: Maximum number of organizations to return
|
||||
|
||||
Returns:
|
||||
Tuple of (list of Org objects, next_page_id or None)
|
||||
"""
|
||||
logger.debug(
|
||||
'Fetching paginated organizations for user',
|
||||
extra={'user_id': user_id, 'page_id': page_id, 'limit': limit},
|
||||
)
|
||||
|
||||
# Convert user_id string to UUID
|
||||
user_uuid = parse_uuid(user_id)
|
||||
|
||||
# Fetch organizations from store
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_uuid, page_id=page_id, limit=limit
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'Retrieved organizations for user',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_count': len(orgs),
|
||||
'has_more': next_page_id is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return orgs, next_page_id
|
||||
|
||||
@staticmethod
|
||||
async def get_org_by_id(org_id: UUID, user_id: str) -> Org:
|
||||
"""
|
||||
Get organization by ID with membership validation.
|
||||
|
||||
This method verifies that the user is a member of the organization
|
||||
before returning the organization details.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
|
||||
Returns:
|
||||
Org: The organization object
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization not found or user is not a member
|
||||
"""
|
||||
logger.info(
|
||||
'Retrieving organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Verify user is a member of the organization
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
logger.warning(
|
||||
'User is not a member of organization or organization does not exist',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Retrieve organization
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
'Organization not found despite valid membership',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organization',
|
||||
extra={
|
||||
'org_id': str(org.id),
|
||||
'org_name': org.name,
|
||||
'user_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def verify_owner_authorization(user_id: str, org_id: UUID) -> None:
|
||||
"""
|
||||
Verify that the user is the owner of the organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
org_id: Organization ID
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not authorized to delete
|
||||
"""
|
||||
# Check if organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Check if user is a member of the organization
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
raise OrgAuthorizationError('User is not a member of this organization')
|
||||
|
||||
# Check if user has owner role
|
||||
role = RoleStore.get_role_by_id(org_member.role_id)
|
||||
if not role or role.name != 'owner':
|
||||
raise OrgAuthorizationError(
|
||||
'Only organization owners can delete organizations'
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'User authorization verified for organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'role': role.name},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_org_with_cleanup(user_id: str, org_id: UUID) -> Org:
|
||||
"""
|
||||
Delete organization with complete cleanup of all associated data.
|
||||
|
||||
This method performs the complete organization deletion workflow:
|
||||
1. Verifies user authorization (owner only)
|
||||
2. Performs database cascade deletion and LiteLLM cleanup in single transaction
|
||||
|
||||
Args:
|
||||
user_id: User ID requesting deletion (must be owner)
|
||||
org_id: Organization ID to delete
|
||||
|
||||
Returns:
|
||||
Org: The deleted organization details
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not authorized to delete
|
||||
OrgDatabaseError: If database operations or LiteLLM cleanup fail
|
||||
"""
|
||||
logger.info(
|
||||
'Starting organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Step 1: Verify user authorization
|
||||
OrgService.verify_owner_authorization(user_id, org_id)
|
||||
|
||||
# Step 2: Perform database cascade deletion with LiteLLM cleanup in transaction
|
||||
try:
|
||||
deleted_org = await OrgStore.delete_org_cascade(org_id)
|
||||
if not deleted_org:
|
||||
# This shouldn't happen since we verified existence above
|
||||
raise OrgDatabaseError('Organization not found during deletion')
|
||||
|
||||
logger.info(
|
||||
'Organization deletion completed successfully',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': deleted_org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return deleted_org
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Organization deletion failed',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
|
||||
|
||||
@@ -10,10 +10,8 @@ from server.constants import (
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
@@ -98,63 +96,6 @@ class OrgStore:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
user_id: UUID, page_id: str | None = None, limit: int = 100
|
||||
) -> tuple[list[Org], str | None]:
|
||||
"""
|
||||
Get paginated list of organizations for a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
page_id: Optional page ID (offset as string) for pagination
|
||||
limit: Maximum number of organizations to return
|
||||
|
||||
Returns:
|
||||
Tuple of (list of Org objects, next_page_id or None)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Build query joining OrgMember with Org
|
||||
query = (
|
||||
session.query(Org)
|
||||
.join(OrgMember, Org.id == OrgMember.org_id)
|
||||
.filter(OrgMember.user_id == user_id)
|
||||
.order_by(Org.name)
|
||||
)
|
||||
|
||||
# Apply pagination offset
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Fetch limit + 1 to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
orgs = query.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(orgs) > limit
|
||||
if has_more:
|
||||
orgs = orgs[:limit]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
# Validate org versions
|
||||
validated_orgs = [
|
||||
OrgStore._validate_org_version(org) for org in orgs if org
|
||||
]
|
||||
validated_orgs = [org for org in validated_orgs if org is not None]
|
||||
|
||||
return validated_orgs, next_page_id
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
org_id: UUID,
|
||||
@@ -245,119 +186,3 @@ class OrgStore:
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
async def delete_org_cascade(org_id: UUID) -> Org | None:
|
||||
"""
|
||||
Delete organization and all associated data in cascade, including external LiteLLM cleanup.
|
||||
|
||||
Args:
|
||||
org_id: UUID of the organization to delete
|
||||
|
||||
Returns:
|
||||
Org: The deleted organization object, or None if not found
|
||||
|
||||
Raises:
|
||||
Exception: If database operations or LiteLLM cleanup fail
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# First get the organization to return it
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. Delete conversation data for organization conversations
|
||||
session.execute(
|
||||
text("""
|
||||
DELETE FROM conversation_metadata
|
||||
WHERE conversation_id IN (
|
||||
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
session.execute(
|
||||
text("""
|
||||
DELETE FROM app_conversation_start_task
|
||||
WHERE app_conversation_id::text IN (
|
||||
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 2. Delete organization-owned data tables (direct org_id foreign keys)
|
||||
session.execute(
|
||||
text('DELETE FROM billing_sessions WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text(
|
||||
'DELETE FROM conversation_metadata_saas WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM custom_secrets WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM api_keys WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM slack_conversation WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM slack_users WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM stripe_customers WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 3. Delete organization memberships
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 4. Handle users with this as current_org_id
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
|
||||
),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 5. Finally delete the organization
|
||||
session.delete(org)
|
||||
|
||||
# 6. Clean up LiteLLM team before committing transaction
|
||||
logger.info(
|
||||
'Deleting LiteLLM team within database transaction',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
await LiteLlmManager.delete_team(str(org_id))
|
||||
|
||||
# 7. Commit all changes only if everything succeeded
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'Successfully deleted organization and all associated data including LiteLLM team',
|
||||
extra={'org_id': str(org_id), 'org_name': org.name},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
'Failed to delete organization - transaction rolled back',
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -4,9 +4,7 @@ Store class for managing roles.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
@@ -35,20 +33,6 @@ class RoleStore:
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
async def get_role_by_name_async(
|
||||
name: str,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
if session is not None:
|
||||
result = await session.execute(select(Role).where(Role.name == name))
|
||||
return result.scalars().first()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Role).where(Role.name == name))
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def list_roles() -> List[Role]:
|
||||
"""List all roles."""
|
||||
|
||||
@@ -31,8 +31,6 @@ class User(Base): # type: ignore
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
email = Column(String, nullable=True)
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
|
||||
@@ -14,13 +14,10 @@ from server.constants import (
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.encrypt_utils import (
|
||||
decrypt_legacy_model,
|
||||
encrypt_legacy_value,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.encrypt_utils import decrypt_legacy_model
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
@@ -119,7 +116,7 @@ class UserStore:
|
||||
redis_client = UserStore._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'user_store:_acquire_user_creation_lock:no_redis_client',
|
||||
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True # Proceed without locking if Redis is unavailable
|
||||
@@ -162,20 +159,12 @@ class UserStore:
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_litellm_migrate_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await LiteLlmManager.migrate_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_litellm_migrate_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
custom_settings = UserStore._has_custom_settings(
|
||||
decrypted_user_settings, user_settings.user_version
|
||||
)
|
||||
@@ -183,15 +172,7 @@ class UserStore:
|
||||
# avoids circular reference. This migrate method is temprorary until all users are migrated.
|
||||
from integrations.stripe_service import migrate_customer
|
||||
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await migrate_customer(session, user_id, org)
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
@@ -220,15 +201,7 @@ class UserStore:
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_get_role_by_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
role = await RoleStore.get_role_by_name_async('owner')
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_get_role_by_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
@@ -241,6 +214,7 @@ class UserStore:
|
||||
if not custom_settings:
|
||||
del org_member_kwargs['llm_model']
|
||||
del org_member_kwargs['llm_base_url']
|
||||
del org_member_kwargs['llm_api_key_for_byor']
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
@@ -255,10 +229,6 @@ class UserStore:
|
||||
user_settings.already_migrated = True
|
||||
session.merge(user_settings)
|
||||
session.flush()
|
||||
logger.debug(
|
||||
'user_store:migrate_user:session_flush_complete',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# need to migrate conversation metadata
|
||||
session.execute(
|
||||
@@ -326,262 +296,8 @@ class UserStore:
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
logger.debug(
|
||||
'user_store:migrate_user:session_committed',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def downgrade_user(user_id: str) -> UserSettings | None:
|
||||
"""Downgrade a migrated user back to the pre-migration state.
|
||||
|
||||
This reverses the migrate_user operation:
|
||||
1. Get the user's settings from user_settings table (migrated users) or
|
||||
create new user_settings from org_members table (new sign-ups)
|
||||
2. Call LiteLlmManager.downgrade_entries to revert LiteLLM state
|
||||
3. Copy user_id from conversation_metadata_saas to conversation_metadata
|
||||
4. Delete conversation_metadata_saas entries
|
||||
5. Reset org_id columns in related tables (stripe_customers, slack_users, etc.)
|
||||
6. Delete the org_member and org entries
|
||||
7. Delete the user entry
|
||||
8. Set already_migrated=False on user_settings
|
||||
|
||||
For new sign-ups (users who registered after migration was deployed),
|
||||
there won't be an existing user_settings entry. In this case, we fall back
|
||||
to the org_members table to get the user's API keys and settings, and create
|
||||
a new user_settings entry for them.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to downgrade
|
||||
|
||||
Returns:
|
||||
The user_settings if downgrade was successful, None otherwise.
|
||||
Returns None if the org has multiple members (not a personal org).
|
||||
"""
|
||||
logger.info(
|
||||
'user_store:downgrade_user:start',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
# Get the user and their org_member
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:user_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the user's personal org (org_id == user_id)
|
||||
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:org_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the user_settings (for migrated users)
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# For new sign-ups after migration, user_settings won't exist
|
||||
# Fall back to getting data from org_members
|
||||
is_new_signup = False
|
||||
if not user_settings:
|
||||
logger.info(
|
||||
'user_store:downgrade_user:user_settings_not_found_checking_org_members',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Get org_members for this org - should only be one for personal orgs
|
||||
org_members = (
|
||||
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
|
||||
)
|
||||
|
||||
if len(org_members) != 1:
|
||||
logger.error(
|
||||
'user_store:downgrade_user:unexpected_org_members_count',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'org_members_count': len(org_members),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
org_member = org_members[0]
|
||||
is_new_signup = True
|
||||
|
||||
# Create a new user_settings entry from OrgMember, User, and Org data
|
||||
# This is needed for new sign-ups who don't have user_settings
|
||||
user_settings = UserStore._create_user_settings_from_entities(
|
||||
user_id, org_member, user, org
|
||||
)
|
||||
session.add(user_settings)
|
||||
session.flush()
|
||||
logger.info(
|
||||
'user_store:downgrade_user:created_user_settings_from_org_member',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Call LiteLLM downgrade
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
logger.debug(
|
||||
'user_store:downgrade_user:calling_litellm_downgrade_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Get the API keys for LiteLLM downgrade
|
||||
if is_new_signup:
|
||||
# For new signups, we already have decrypted values in user_settings
|
||||
decrypted_user_settings = user_settings
|
||||
else:
|
||||
# For migrated users, decrypt the legacy model
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
|
||||
await LiteLlmManager.downgrade_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
logger.debug(
|
||||
'user_store:downgrade_user:done_litellm_downgrade_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Step 3: Copy user_id from conversation_metadata_saas to conversation_metadata
|
||||
# This ensures any conversations created after migration have their user_id
|
||||
# preserved in the original table before we delete the saas entries
|
||||
session.execute(
|
||||
text("""
|
||||
UPDATE conversation_metadata
|
||||
SET user_id = :user_id
|
||||
WHERE conversation_id IN (
|
||||
SELECT conversation_id
|
||||
FROM conversation_metadata_saas
|
||||
WHERE user_id = :user_uuid
|
||||
)
|
||||
"""),
|
||||
{'user_id': user_id, 'user_uuid': user_uuid},
|
||||
)
|
||||
|
||||
# Step 4: Delete conversation_metadata_saas entries
|
||||
session.execute(
|
||||
text('DELETE FROM conversation_metadata_saas WHERE user_id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 5: Reset org_id columns in related tables
|
||||
# Reset stripe_customers
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset slack_users
|
||||
session.execute(
|
||||
text('UPDATE slack_users SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset slack_conversation
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset api_keys
|
||||
session.execute(
|
||||
text('UPDATE api_keys SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset custom_secrets
|
||||
session.execute(
|
||||
text('UPDATE custom_secrets SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset billing_sessions
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 6: Delete org_member entries for this org
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 7: Delete the user entry
|
||||
session.execute(
|
||||
text('DELETE FROM "user" WHERE id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Delete the org entry
|
||||
session.execute(
|
||||
text('DELETE FROM org WHERE id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 8: Set already_migrated=False on user_settings and encrypt fields
|
||||
user_settings.already_migrated = False
|
||||
|
||||
# Re-encrypt the sensitive fields before storing in the DB
|
||||
encrypt_keys = [
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
]
|
||||
for key in encrypt_keys:
|
||||
value = getattr(user_settings, key, None)
|
||||
if value is not None:
|
||||
setattr(user_settings, key, encrypt_legacy_value(value))
|
||||
|
||||
session.merge(user_settings)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_store:downgrade_user:complete',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (sync version).
|
||||
@@ -606,7 +322,7 @@ class UserStore:
|
||||
):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'user_store:create_default_settings:waiting_for_lock',
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
call_async_from_sync(
|
||||
@@ -656,13 +372,13 @@ class UserStore:
|
||||
This is the preferred method when calling from an async context as it
|
||||
avoids event loop conflicts that can occur with the sync version.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
@@ -670,39 +386,32 @@ class UserStore:
|
||||
while not await UserStore._acquire_user_creation_lock(user_id):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:waiting_for_lock',
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:start_migration',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_settings = result.scalars().first()
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id)
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:calling_migrate_user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
user = await UserStore.migrate_user(
|
||||
user_id,
|
||||
user_settings,
|
||||
@@ -772,96 +481,6 @@ class UserStore:
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def _create_user_settings_from_entities(
|
||||
user_id: str, org_member: OrgMember, user: User, org: Org
|
||||
) -> UserSettings:
|
||||
"""Create UserSettings from OrgMember, User, and Org data.
|
||||
|
||||
Uses OrgMember values first. If an OrgMember field is None and there's
|
||||
a corresponding "default_" field in Org, use the Org value.
|
||||
Also pulls relevant fields from User.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID
|
||||
org_member: The OrgMember entity
|
||||
user: The User entity
|
||||
org: The Org entity
|
||||
|
||||
Returns:
|
||||
A new UserSettings object populated from the entities
|
||||
"""
|
||||
# Mapping from OrgMember fields to corresponding Org "default_" fields
|
||||
org_member_to_org_default = {
|
||||
'llm_model': 'default_llm_model',
|
||||
'llm_base_url': 'default_llm_base_url',
|
||||
'max_iterations': 'default_max_iterations',
|
||||
}
|
||||
|
||||
def get_value_with_org_fallback(field_name: str, org_member_value):
|
||||
"""Get value from OrgMember, falling back to Org default if None."""
|
||||
if org_member_value is not None:
|
||||
return org_member_value
|
||||
org_default_field = org_member_to_org_default.get(field_name)
|
||||
if org_default_field and hasattr(org, org_default_field):
|
||||
return getattr(org, org_default_field)
|
||||
return None
|
||||
|
||||
# Get values from OrgMember with Org fallback for fields with default_ prefix
|
||||
llm_model = get_value_with_org_fallback('llm_model', org_member.llm_model)
|
||||
llm_base_url = get_value_with_org_fallback(
|
||||
'llm_base_url', org_member.llm_base_url
|
||||
)
|
||||
max_iterations = get_value_with_org_fallback(
|
||||
'max_iterations', org_member.max_iterations
|
||||
)
|
||||
|
||||
return UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
# OrgMember fields
|
||||
llm_api_key=org_member.llm_api_key.get_secret_value()
|
||||
if org_member.llm_api_key
|
||||
else None,
|
||||
llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value()
|
||||
if org_member.llm_api_key_for_byor
|
||||
else None,
|
||||
llm_model=llm_model,
|
||||
llm_base_url=llm_base_url,
|
||||
max_iterations=max_iterations,
|
||||
# User fields
|
||||
accepted_tos=user.accepted_tos,
|
||||
enable_sound_notifications=user.enable_sound_notifications,
|
||||
language=user.language,
|
||||
user_consents_to_analytics=user.user_consents_to_analytics,
|
||||
email=user.email,
|
||||
email_verified=user.email_verified,
|
||||
git_user_name=user.git_user_name,
|
||||
git_user_email=user.git_user_email,
|
||||
# Org fields
|
||||
agent=org.agent,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
user_version=org.org_version,
|
||||
mcp_config=org.mcp_config,
|
||||
search_api_key=org.search_api_key.get_secret_value()
|
||||
if org.search_api_key
|
||||
else None,
|
||||
sandbox_api_key=org.sandbox_api_key.get_secret_value()
|
||||
if org.sandbox_api_key
|
||||
else None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
condenser_max_size=org.condenser_max_size,
|
||||
already_migrated=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _has_custom_settings(
|
||||
user_settings: UserSettings, old_user_version: int | None
|
||||
|
||||
@@ -6,13 +6,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.jira.jira_manager import JiraManager
|
||||
from integrations.jira.jira_payload import (
|
||||
JiraEventType,
|
||||
JiraWebhookPayload,
|
||||
)
|
||||
from integrations.jira.jira_view import (
|
||||
JiraNewConversationView,
|
||||
)
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import DictLoader, Environment
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_user import JiraUser
|
||||
@@ -27,7 +24,7 @@ def mock_token_manager():
|
||||
"""Create a mock TokenManager for testing."""
|
||||
token_manager = MagicMock()
|
||||
token_manager.get_user_id_from_user_email = AsyncMock()
|
||||
token_manager.decrypt_text = MagicMock(return_value='decrypted_key')
|
||||
token_manager.decrypt_text = MagicMock()
|
||||
return token_manager
|
||||
|
||||
|
||||
@@ -62,7 +59,6 @@ def sample_jira_workspace():
|
||||
workspace = MagicMock(spec=JiraWorkspace)
|
||||
workspace.id = 1
|
||||
workspace.name = 'test.atlassian.net'
|
||||
workspace.jira_cloud_id = 'cloud-123'
|
||||
workspace.admin_user_id = 'admin_id'
|
||||
workspace.webhook_secret = 'encrypted_secret'
|
||||
workspace.svc_acc_email = 'service@example.com'
|
||||
@@ -78,41 +74,22 @@ def sample_user_auth():
|
||||
user_auth.get_provider_tokens = AsyncMock(return_value={})
|
||||
user_auth.get_access_token = AsyncMock(return_value='test_token')
|
||||
user_auth.get_user_id = AsyncMock(return_value='test_user_id')
|
||||
user_auth.get_secrets = AsyncMock(return_value=None)
|
||||
return user_auth
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_webhook_payload():
|
||||
"""Create a sample JiraWebhookPayload for testing."""
|
||||
return JiraWebhookPayload(
|
||||
event_type=JiraEventType.COMMENT_MENTION,
|
||||
raw_event='comment_created',
|
||||
def sample_job_context():
|
||||
"""Create a sample JobContext for testing."""
|
||||
return JobContext(
|
||||
issue_id='12345',
|
||||
issue_key='TEST-123',
|
||||
user_msg='Fix this bug @openhands',
|
||||
user_email='user@test.com',
|
||||
display_name='Test User',
|
||||
account_id='user123',
|
||||
workspace_name='test.atlassian.net',
|
||||
base_api_url='https://test.atlassian.net',
|
||||
comment_body='Fix this bug @openhands',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_label_webhook_payload():
|
||||
"""Create a sample labeled ticket JiraWebhookPayload for testing."""
|
||||
return JiraWebhookPayload(
|
||||
event_type=JiraEventType.LABELED_TICKET,
|
||||
raw_event='jira:issue_updated',
|
||||
issue_id='12345',
|
||||
issue_key='PROJ-123',
|
||||
user_email='user@company.com',
|
||||
display_name='Test User',
|
||||
account_id='user456',
|
||||
workspace_name='jira.company.com',
|
||||
base_api_url='https://jira.company.com',
|
||||
comment_body='',
|
||||
issue_title='Test Issue',
|
||||
issue_description='This is a test issue',
|
||||
)
|
||||
|
||||
|
||||
@@ -203,17 +180,16 @@ def jira_conversation():
|
||||
|
||||
@pytest.fixture
|
||||
def new_conversation_view(
|
||||
sample_webhook_payload, sample_user_auth, sample_jira_user, sample_jira_workspace
|
||||
sample_job_context, sample_user_auth, sample_jira_user, sample_jira_workspace
|
||||
):
|
||||
"""JiraNewConversationView instance for testing"""
|
||||
return JiraNewConversationView(
|
||||
payload=sample_webhook_payload,
|
||||
job_context=sample_job_context,
|
||||
saas_user_auth=sample_user_auth,
|
||||
jira_user=sample_jira_user,
|
||||
jira_workspace=sample_jira_workspace,
|
||||
selected_repo='test/repo1',
|
||||
conversation_id='conv-123',
|
||||
_decrypted_api_key='decrypted_key',
|
||||
)
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,90 +2,29 @@
|
||||
Tests for Jira view classes and factory.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.jira.jira_payload import (
|
||||
JiraEventType,
|
||||
JiraPayloadError,
|
||||
JiraPayloadParser,
|
||||
JiraPayloadSkipped,
|
||||
JiraPayloadSuccess,
|
||||
)
|
||||
from integrations.jira.jira_types import RepositoryNotFoundError, StartingConvoException
|
||||
from integrations.jira.jira_types import StartingConvoException
|
||||
from integrations.jira.jira_view import (
|
||||
JiraFactory,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
from integrations.models import Message, SourceType
|
||||
|
||||
|
||||
class TestJiraNewConversationView:
|
||||
"""Tests for JiraNewConversationView"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue_details_success(
|
||||
self, new_conversation_view, sample_jira_workspace
|
||||
):
|
||||
"""Test successful issue details retrieval."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
title, description = await new_conversation_view.get_issue_details()
|
||||
|
||||
assert title == 'Test Issue'
|
||||
assert description == 'Test description'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue_details_cached(self, new_conversation_view):
|
||||
"""Test issue details are cached after first call."""
|
||||
new_conversation_view._issue_title = 'Cached Title'
|
||||
new_conversation_view._issue_description = 'Cached Description'
|
||||
|
||||
title, description = await new_conversation_view.get_issue_details()
|
||||
|
||||
assert title == 'Cached Title'
|
||||
assert description == 'Cached Description'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue_details_no_title(self, new_conversation_view):
|
||||
"""Test issue details with no title raises error."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': '', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(StartingConvoException, match='does not have a title'):
|
||||
await new_conversation_view.get_issue_details()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method fetches issue details."""
|
||||
new_conversation_view._issue_title = 'Test Issue'
|
||||
new_conversation_view._issue_description = 'This is a test issue'
|
||||
|
||||
instructions, user_msg = await new_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
|
||||
|
||||
assert instructions == 'Test Jira instructions template'
|
||||
assert 'TEST-123' in user_msg
|
||||
assert 'Test Issue' in user_msg
|
||||
assert 'Fix this bug @openhands' in user_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.create_new_conversation')
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_or_update_conversation_success(
|
||||
@@ -97,8 +36,6 @@ class TestJiraNewConversationView:
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test successful conversation creation"""
|
||||
new_conversation_view._issue_title = 'Test Issue'
|
||||
new_conversation_view._issue_description = 'Test description'
|
||||
mock_create_conversation.return_value = mock_agent_loop_info
|
||||
mock_store.create_conversation = AsyncMock()
|
||||
|
||||
@@ -110,7 +47,6 @@ class TestJiraNewConversationView:
|
||||
mock_create_conversation.assert_called_once()
|
||||
mock_store.create_conversation.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_or_update_conversation_no_repo(
|
||||
self, new_conversation_view, mock_jinja_env
|
||||
):
|
||||
@@ -120,6 +56,18 @@ class TestJiraNewConversationView:
|
||||
with pytest.raises(StartingConvoException, match='No repository selected'):
|
||||
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
|
||||
|
||||
@patch('integrations.jira.jira_view.create_new_conversation')
|
||||
async def test_create_or_update_conversation_failure(
|
||||
self, mock_create_conversation, new_conversation_view, mock_jinja_env
|
||||
):
|
||||
"""Test conversation creation failure"""
|
||||
mock_create_conversation.side_effect = Exception('Creation failed')
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Failed to create conversation'
|
||||
):
|
||||
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
|
||||
|
||||
def test_get_response_msg(self, new_conversation_view):
|
||||
"""Test get_response_msg method"""
|
||||
response = new_conversation_view.get_response_msg()
|
||||
@@ -133,333 +81,469 @@ class TestJiraNewConversationView:
|
||||
class TestJiraFactory:
|
||||
"""Tests for JiraFactory"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
@patch('integrations.jira.jira_view.infer_repo_from_message')
|
||||
async def test_create_view_success(
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_jira_view_from_payload_new_conversation(
|
||||
self,
|
||||
mock_infer_repos,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
sample_repositories,
|
||||
):
|
||||
"""Test factory creating view with repo selection."""
|
||||
# Setup mock provider handler
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.verify_repo_provider = AsyncMock(
|
||||
return_value=sample_repositories[0]
|
||||
)
|
||||
mock_create_handler.return_value = mock_handler
|
||||
|
||||
# Mock repo inference to return a repo name
|
||||
mock_infer_repos.return_value = ['test/repo1']
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
view = await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
assert isinstance(view, JiraNewConversationView)
|
||||
assert view.selected_repo == 'test/repo1'
|
||||
mock_handler.verify_repo_provider.assert_called_once_with('test/repo1')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
@patch('integrations.jira.jira_view.infer_repo_from_message')
|
||||
async def test_create_view_no_repo_in_text(
|
||||
self,
|
||||
mock_infer_repos,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
mock_store,
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
):
|
||||
"""Test factory raises error when no repo mentioned in text."""
|
||||
mock_handler = MagicMock()
|
||||
mock_create_handler.return_value = mock_handler
|
||||
"""Test factory creating new conversation view"""
|
||||
mock_store.get_user_conversations_by_issue_id = AsyncMock(return_value=None)
|
||||
|
||||
# No repos found in text
|
||||
mock_infer_repos.return_value = []
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RepositoryNotFoundError, match='Could not determine which repository'
|
||||
):
|
||||
await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
@patch('integrations.jira.jira_view.infer_repo_from_message')
|
||||
async def test_create_view_repo_verification_fails(
|
||||
self,
|
||||
mock_infer_repos,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
):
|
||||
"""Test factory raises error when repo verification fails."""
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.verify_repo_provider = AsyncMock(
|
||||
side_effect=Exception('Repository not found')
|
||||
)
|
||||
mock_create_handler.return_value = mock_handler
|
||||
|
||||
# Repos found in text but verification fails
|
||||
mock_infer_repos.return_value = ['test/repo1', 'test/repo2']
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RepositoryNotFoundError,
|
||||
match='Could not access any of the mentioned repositories',
|
||||
):
|
||||
await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
@patch('integrations.jira.jira_view.infer_repo_from_message')
|
||||
async def test_create_view_multiple_repos_verified(
|
||||
self,
|
||||
mock_infer_repos,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
sample_repositories,
|
||||
):
|
||||
"""Test factory raises error when multiple repos are verified."""
|
||||
mock_handler = MagicMock()
|
||||
# Both repos verify successfully
|
||||
mock_handler.verify_repo_provider = AsyncMock(
|
||||
side_effect=[sample_repositories[0], sample_repositories[1]]
|
||||
)
|
||||
mock_create_handler.return_value = mock_handler
|
||||
|
||||
# Multiple repos found in text
|
||||
mock_infer_repos.return_value = ['test/repo1', 'test/repo2']
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RepositoryNotFoundError, match='Multiple repositories found'
|
||||
):
|
||||
await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
async def test_create_view_no_provider(
|
||||
self,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
):
|
||||
"""Test factory raises error when no provider is connected."""
|
||||
mock_create_handler.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RepositoryNotFoundError, match='No Git provider connected'
|
||||
):
|
||||
await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
|
||||
class TestJiraPayloadParser:
|
||||
"""Tests for JiraPayloadParser"""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create a parser for testing."""
|
||||
return JiraPayloadParser(oh_label='openhands', inline_oh_label='@openhands')
|
||||
|
||||
def test_parse_label_event_success(
|
||||
self, parser, sample_issue_update_webhook_payload
|
||||
):
|
||||
"""Test parsing label event."""
|
||||
result = parser.parse(sample_issue_update_webhook_payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.LABELED_TICKET
|
||||
assert result.payload.issue_key == 'PROJ-123'
|
||||
|
||||
def test_parse_comment_event_success(self, parser, sample_comment_webhook_payload):
|
||||
"""Test parsing comment event."""
|
||||
result = parser.parse(sample_comment_webhook_payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.COMMENT_MENTION
|
||||
assert result.payload.issue_key == 'TEST-123'
|
||||
assert '@openhands' in result.payload.comment_body
|
||||
|
||||
def test_parse_unknown_event_skipped(self, parser):
|
||||
"""Test unknown event is skipped."""
|
||||
payload = {'webhookEvent': 'unknown_event'}
|
||||
result = parser.parse(payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'Unhandled webhook event type' in result.skip_reason
|
||||
|
||||
def test_parse_label_event_wrong_label_skipped(self, parser):
|
||||
"""Test label event without OH label is skipped."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': [{'field': 'labels', 'toString': 'other-label'}]},
|
||||
}
|
||||
result = parser.parse(payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'does not contain' in result.skip_reason
|
||||
|
||||
def test_parse_comment_event_no_mention_skipped(self, parser):
|
||||
"""Test comment without mention is skipped."""
|
||||
payload = {
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {
|
||||
'body': 'Regular comment',
|
||||
'author': {'emailAddress': 'test@test.com'},
|
||||
},
|
||||
}
|
||||
result = parser.parse(payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'does not mention' in result.skip_reason
|
||||
|
||||
def test_parse_missing_fields_error(self, parser):
|
||||
"""Test missing required fields returns error."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands'}]},
|
||||
'issue': {'id': '123'}, # Missing key
|
||||
'user': {'emailAddress': 'test@test.com'}, # Missing other fields
|
||||
}
|
||||
result = parser.parse(payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'Missing required fields' in result.error
|
||||
|
||||
|
||||
class TestJiraPayloadParserStagingLabels:
|
||||
"""Tests for JiraPayloadParser with staging labels."""
|
||||
|
||||
@pytest.fixture
|
||||
def staging_parser(self):
|
||||
"""Create a parser with staging labels."""
|
||||
return JiraPayloadParser(
|
||||
oh_label='openhands-exp', inline_oh_label='@openhands-exp'
|
||||
view = await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
)
|
||||
|
||||
def test_parse_staging_label(self, staging_parser):
|
||||
"""Test parsing with staging label."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands-exp'}]},
|
||||
'issue': {
|
||||
'id': '123',
|
||||
'key': 'TEST-1',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/issue/123',
|
||||
},
|
||||
'user': {
|
||||
'emailAddress': 'test@test.com',
|
||||
'displayName': 'Test',
|
||||
'accountId': 'acc123',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/user',
|
||||
},
|
||||
}
|
||||
result = staging_parser.parse(payload)
|
||||
assert isinstance(view, JiraNewConversationView)
|
||||
assert view.conversation_id == ''
|
||||
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.LABELED_TICKET
|
||||
async def test_create_jira_view_from_payload_no_user(
|
||||
self, sample_job_context, sample_user_auth, sample_jira_workspace
|
||||
):
|
||||
"""Test factory with no Jira user"""
|
||||
with pytest.raises(StartingConvoException, match='User not authenticated'):
|
||||
await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
None,
|
||||
sample_jira_workspace, # type: ignore
|
||||
)
|
||||
|
||||
def test_parse_prod_label_in_staging_skipped(self, staging_parser):
|
||||
"""Test prod label is skipped in staging environment."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands'}]},
|
||||
}
|
||||
result = staging_parser.parse(payload)
|
||||
async def test_create_jira_view_from_payload_no_auth(
|
||||
self, sample_job_context, sample_jira_user, sample_jira_workspace
|
||||
):
|
||||
"""Test factory with no SaaS auth"""
|
||||
with pytest.raises(StartingConvoException, match='User not authenticated'):
|
||||
await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
None,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace, # type: ignore
|
||||
)
|
||||
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
async def test_create_jira_view_from_payload_no_workspace(
|
||||
self, sample_job_context, sample_user_auth, sample_jira_user
|
||||
):
|
||||
"""Test factory with no workspace"""
|
||||
with pytest.raises(StartingConvoException, match='User not authenticated'):
|
||||
await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class TestJiraViewEdgeCases:
|
||||
"""Tests for edge cases and error scenarios"""
|
||||
|
||||
@patch('integrations.jira.jira_view.create_new_conversation')
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_conversation_creation_with_no_user_secrets(
|
||||
self,
|
||||
mock_store,
|
||||
mock_create_conversation,
|
||||
new_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test conversation creation when user has no secrets"""
|
||||
new_conversation_view.saas_user_auth.get_secrets.return_value = None
|
||||
mock_create_conversation.return_value = mock_agent_loop_info
|
||||
mock_store.create_conversation = AsyncMock()
|
||||
|
||||
result = await new_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert result == 'conv-123'
|
||||
# Verify create_new_conversation was called with custom_secrets=None
|
||||
call_kwargs = mock_create_conversation.call_args[1]
|
||||
assert call_kwargs['custom_secrets'] is None
|
||||
|
||||
@patch('integrations.jira.jira_view.create_new_conversation')
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_conversation_creation_store_failure(
|
||||
self,
|
||||
mock_store,
|
||||
mock_create_conversation,
|
||||
new_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test conversation creation when store creation fails"""
|
||||
mock_create_conversation.return_value = mock_agent_loop_info
|
||||
mock_store.create_conversation = AsyncMock(side_effect=Exception('Store error'))
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Failed to create conversation'
|
||||
):
|
||||
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
|
||||
|
||||
def test_new_conversation_view_attributes(self, new_conversation_view):
|
||||
"""Test new conversation view attribute access"""
|
||||
assert new_conversation_view.job_context.issue_key == 'TEST-123'
|
||||
assert new_conversation_view.selected_repo == 'test/repo1'
|
||||
assert new_conversation_view.conversation_id == 'conv-123'
|
||||
|
||||
|
||||
class TestJiraFactoryIsLabeledTicket:
|
||||
"""Parameterized tests for JiraFactory.is_labeled_ticket method."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'payload,expected',
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
True,
|
||||
id='issue_updated_with_openhands_label',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [
|
||||
{'field': 'labels', 'toString': 'bug'},
|
||||
{'field': 'labels', 'toString': 'openhands'},
|
||||
]
|
||||
},
|
||||
},
|
||||
True,
|
||||
id='issue_updated_with_multiple_labels_including_openhands',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'bug,urgent'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_without_openhands_label',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': []},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_empty_changelog_items',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_empty_changelog',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
},
|
||||
False,
|
||||
id='issue_updated_without_changelog',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='comment_created_event_with_label',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'issue_deleted',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='unsupported_event_type',
|
||||
),
|
||||
pytest.param(
|
||||
{},
|
||||
False,
|
||||
id='empty_payload',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'status', 'toString': 'In Progress'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_non_label_field',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'fromString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_fromString_instead_of_toString',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [
|
||||
{'field': 'labels', 'toString': 'not-openhands'},
|
||||
{'field': 'priority', 'toString': 'High'},
|
||||
]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_mixed_fields_no_openhands',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_labeled_ticket(self, payload, expected):
|
||||
"""Test is_labeled_ticket with various payloads."""
|
||||
with patch('integrations.jira.jira_view.OH_LABEL', 'openhands'):
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
result = JiraFactory.is_labeled_ticket(message)
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'payload,expected',
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands-exp'}]
|
||||
},
|
||||
},
|
||||
True,
|
||||
id='issue_updated_with_staging_label',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_prod_label_in_staging_env',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_labeled_ticket_staging_labels(self, payload, expected):
|
||||
"""Test is_labeled_ticket with staging environment labels."""
|
||||
with patch('integrations.jira.jira_view.OH_LABEL', 'openhands-exp'):
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
result = JiraFactory.is_labeled_ticket(message)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestJiraFactoryIsTicketComment:
|
||||
"""Parameterized tests for JiraFactory.is_ticket_comment method."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'payload,expected',
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Please fix this @openhands'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_openhands_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': '@openhands please help'},
|
||||
},
|
||||
True,
|
||||
id='comment_starting_with_openhands_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Hello @openhands!'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_openhands_mention_and_punctuation',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': '(@openhands)'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_openhands_in_parentheses',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Hey @OpenHands can you help?'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_case_insensitive_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Hey @OPENHANDS!'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_uppercase_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Regular comment without mention'},
|
||||
},
|
||||
False,
|
||||
id='comment_without_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Hello @openhands-agent!'},
|
||||
},
|
||||
False,
|
||||
id='comment_with_openhands_as_prefix',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'user@openhands.com'},
|
||||
},
|
||||
False,
|
||||
id='comment_with_openhands_in_email',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': ''},
|
||||
},
|
||||
False,
|
||||
id='comment_with_empty_body',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {},
|
||||
},
|
||||
False,
|
||||
id='comment_without_body',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
},
|
||||
False,
|
||||
id='comment_created_without_comment_data',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'comment': {'body': 'Please fix this @openhands'},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_event_with_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'issue_deleted',
|
||||
'comment': {'body': '@openhands'},
|
||||
},
|
||||
False,
|
||||
id='unsupported_event_type_with_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{},
|
||||
False,
|
||||
id='empty_payload',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Multiple @openhands @openhands mentions'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_multiple_mentions',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_ticket_comment(self, payload, expected):
|
||||
"""Test is_ticket_comment with various payloads."""
|
||||
with patch('integrations.jira.jira_view.INLINE_OH_LABEL', '@openhands'), patch(
|
||||
'integrations.jira.jira_view.has_exact_mention'
|
||||
) as mock_has_exact_mention:
|
||||
from integrations.utils import has_exact_mention
|
||||
|
||||
mock_has_exact_mention.side_effect = (
|
||||
lambda text, mention: has_exact_mention(text, mention)
|
||||
)
|
||||
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
result = JiraFactory.is_ticket_comment(message)
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'payload,expected',
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Please fix this @openhands-exp'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_staging_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': '@openhands-exp please help'},
|
||||
},
|
||||
True,
|
||||
id='comment_starting_with_staging_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Please fix this @openhands'},
|
||||
},
|
||||
False,
|
||||
id='comment_with_prod_mention_in_staging_env',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_ticket_comment_staging_labels(self, payload, expected):
|
||||
"""Test is_ticket_comment with staging environment labels."""
|
||||
with patch(
|
||||
'integrations.jira.jira_view.INLINE_OH_LABEL', '@openhands-exp'
|
||||
), patch(
|
||||
'integrations.jira.jira_view.has_exact_mention'
|
||||
) as mock_has_exact_mention:
|
||||
from integrations.utils import has_exact_mention
|
||||
|
||||
mock_has_exact_mention.side_effect = (
|
||||
lambda text, mention: has_exact_mention(text, mention)
|
||||
)
|
||||
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
result = JiraFactory.is_ticket_comment(message)
|
||||
assert result == expected
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import SecretStr
|
||||
from server.routes.github_proxy import add_github_proxy_routes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_github_proxy(monkeypatch):
|
||||
"""Create a FastAPI app with github proxy routes enabled."""
|
||||
# Enable the github proxy endpoints
|
||||
monkeypatch.setenv('GITHUB_PROXY_ENDPOINTS', '1')
|
||||
|
||||
# Mock the config to have a jwt_secret
|
||||
mock_config = type(
|
||||
'MockConfig', (), {'jwt_secret': SecretStr('test-secret-key-for-testing')}
|
||||
)()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
with patch('server.routes.github_proxy.GITHUB_PROXY_ENDPOINTS', True):
|
||||
with patch('server.routes.github_proxy.config', mock_config):
|
||||
add_github_proxy_routes(app)
|
||||
|
||||
# Return app and mock_config so we can use the same config in tests
|
||||
return app, mock_config
|
||||
|
||||
|
||||
def test_state_compress_encrypt_and_decrypt_decompress_roundtrip(
|
||||
app_with_github_proxy, monkeypatch
|
||||
):
|
||||
"""
|
||||
Verify the code path used by github_proxy_start -> github_proxy_callback:
|
||||
- compress payload, encrypt, base64-encode (what the start code does)
|
||||
- base64-decode, decrypt, decompress (what the callback code does)
|
||||
|
||||
This test exercises the actual endpoints to verify the roundtrip works correctly.
|
||||
"""
|
||||
app, mock_config = app_with_github_proxy
|
||||
client = TestClient(app)
|
||||
|
||||
original_state = 'some-state-value'
|
||||
original_redirect_uri = 'https://example.com/redirect'
|
||||
|
||||
# Call github_proxy_start endpoint - it should redirect to GitHub with encrypted state
|
||||
with patch('server.routes.github_proxy.config', mock_config):
|
||||
response = client.get(
|
||||
'/github-proxy/test-subdomain/login/oauth/authorize',
|
||||
params={
|
||||
'state': original_state,
|
||||
'redirect_uri': original_redirect_uri,
|
||||
'client_id': 'test-client-id',
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 307
|
||||
redirect_url = response.headers['location']
|
||||
|
||||
# Verify it redirects to GitHub
|
||||
assert redirect_url.startswith('https://github.com/login/oauth/authorize')
|
||||
|
||||
# Parse the redirect URL to get the encrypted state
|
||||
parsed = urlparse(redirect_url)
|
||||
query_params = parse_qs(parsed.query)
|
||||
encrypted_state = query_params['state'][0]
|
||||
|
||||
# The redirect_uri should now point to our callback
|
||||
assert 'github-proxy/callback' in query_params['redirect_uri'][0]
|
||||
|
||||
# Now simulate GitHub calling back with this encrypted state
|
||||
with patch('server.routes.github_proxy.config', mock_config):
|
||||
callback_response = client.get(
|
||||
'/github-proxy/callback',
|
||||
params={
|
||||
'state': encrypted_state,
|
||||
'code': 'test-auth-code',
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert callback_response.status_code == 307
|
||||
final_redirect = callback_response.headers['location']
|
||||
|
||||
# Verify the callback redirects to the original redirect_uri
|
||||
assert final_redirect.startswith(original_redirect_uri)
|
||||
|
||||
# Parse the final redirect to verify the state was decrypted correctly
|
||||
final_parsed = urlparse(final_redirect)
|
||||
final_params = parse_qs(final_parsed.query)
|
||||
|
||||
assert final_params['state'][0] == original_state
|
||||
assert final_params['code'][0] == 'test-auth-code'
|
||||
File diff suppressed because it is too large
Load Diff
@@ -163,7 +163,7 @@ async def test_create_checkout_session_stripe_error(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
await create_checkout_session(
|
||||
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
|
||||
@@ -204,7 +204,7 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
@@ -236,8 +236,8 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={'payment_method_save': 'enabled'},
|
||||
success_url='https://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
|
||||
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
|
||||
success_url='http://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
|
||||
cancel_url='http://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
|
||||
)
|
||||
|
||||
# Verify database session creation
|
||||
@@ -331,7 +331,7 @@ async def test_success_callback_success():
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=success'
|
||||
== 'http://test.com/settings/billing?checkout=success'
|
||||
)
|
||||
|
||||
# Verify LiteLLM API calls
|
||||
@@ -402,7 +402,7 @@ async def test_cancel_callback_session_not_found():
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify no database updates occurred
|
||||
@@ -429,7 +429,7 @@ async def test_cancel_callback_success():
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
@@ -490,7 +490,7 @@ async def test_create_customer_setup_session_success():
|
||||
AsyncMock(return_value=mock_customer_info),
|
||||
),
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
result = await create_customer_setup_session(mock_request, 'mock_user')
|
||||
|
||||
@@ -502,6 +502,6 @@ async def test_create_customer_setup_session_success():
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='https://test.com/?free_credits=success',
|
||||
cancel_url='https://test.com/',
|
||||
success_url='http://test.com/?free_credits=success',
|
||||
cancel_url='http://test.com/',
|
||||
)
|
||||
|
||||
@@ -1126,174 +1126,3 @@ class TestLiteLlmManager:
|
||||
'http://test.url/team/delete',
|
||||
json={'team_ids': [team_id]},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_from_team_successful(self):
|
||||
"""
|
||||
GIVEN: Valid user_id and team_id
|
||||
WHEN: _remove_user_from_team is called
|
||||
THEN: HTTP POST is made to remove user from team
|
||||
"""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
|
||||
with (
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
await LiteLlmManager._remove_user_from_team(
|
||||
mock_client, 'test-user-id', 'test-team-id'
|
||||
)
|
||||
|
||||
mock_client.post.assert_called_once_with(
|
||||
'http://test.url/team/member_delete',
|
||||
json={
|
||||
'team_id': 'test-team-id',
|
||||
'user_id': 'test-user-id',
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_from_team_not_found(self):
|
||||
"""
|
||||
GIVEN: User not in team
|
||||
WHEN: _remove_user_from_team is called
|
||||
THEN: 404 response is handled gracefully without raising
|
||||
"""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = False
|
||||
mock_response.status_code = 404
|
||||
mock_response.text = 'User not found in team'
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with (
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Should not raise an exception
|
||||
await LiteLlmManager._remove_user_from_team(
|
||||
mock_client, 'test-user-id', 'test-team-id'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_missing_config(self, mock_user_settings):
|
||||
"""Test downgrade_entries when LiteLLM config is missing."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_team_not_found(self, mock_user_settings):
|
||||
"""Test downgrade_entries when team is not found."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_team', new_callable=AsyncMock
|
||||
) as mock_get_team:
|
||||
mock_get_team.return_value = None
|
||||
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_successful(self, mock_user_settings):
|
||||
"""Test successful downgrade_entries operation."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_team_info_response = MagicMock()
|
||||
mock_team_info_response.is_success = True
|
||||
mock_team_info_response.status_code = 200
|
||||
mock_team_info_response.json.return_value = {
|
||||
'team_info': {
|
||||
'max_budget': 100.0,
|
||||
'spend': 20.0,
|
||||
},
|
||||
'team_memberships': [
|
||||
{
|
||||
'user_id': 'test-user-id',
|
||||
'team_id': 'test-org-id',
|
||||
'max_budget_in_team': 100.0,
|
||||
'spend': 20.0,
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_team_info_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'default-team'
|
||||
):
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_team_info_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# downgrade_entries returns the user_settings
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
|
||||
# Verify downgrade steps were called:
|
||||
# 1. get_team (GET)
|
||||
# 2. get_user_team_info (GET via _get_team)
|
||||
# 3. update_user (POST)
|
||||
# 4. add_user_to_team (POST)
|
||||
# 5. update_key (POST)
|
||||
# 6. remove_user_from_team (POST)
|
||||
# 7. delete_team (POST)
|
||||
assert mock_client.get.call_count >= 1
|
||||
assert mock_client.post.call_count >= 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_local_deployment(self, mock_user_settings):
|
||||
"""Test downgrade_entries in local deployment mode (skips LiteLLM calls)."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': 'true'}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# In local deployment, should return user_settings without
|
||||
# making any LiteLLM calls
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
|
||||
@@ -68,84 +68,3 @@ def test_user_model(session_maker):
|
||||
)
|
||||
assert queried_org_member is not None
|
||||
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
|
||||
|
||||
def test_user_model_git_user_fields(session_maker):
|
||||
"""Test that git_user_name and git_user_email columns exist and work correctly."""
|
||||
with session_maker() as session:
|
||||
# Arrange
|
||||
org = Org(name='test_org_git')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
test_user_id = uuid4()
|
||||
|
||||
# Act
|
||||
user = User(
|
||||
id=test_user_id,
|
||||
current_org_id=org.id,
|
||||
git_user_name='Test Git Author',
|
||||
git_user_email='git@example.com',
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Assert
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user.git_user_name == 'Test Git Author'
|
||||
assert queried_user.git_user_email == 'git@example.com'
|
||||
|
||||
|
||||
def test_user_model_git_user_fields_nullable(session_maker):
|
||||
"""Test that git_user_name and git_user_email can be null."""
|
||||
with session_maker() as session:
|
||||
# Arrange
|
||||
org = Org(name='test_org_nullable')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
test_user_id = uuid4()
|
||||
|
||||
# Act - create user without git fields
|
||||
user = User(
|
||||
id=test_user_id,
|
||||
current_org_id=org.id,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Assert
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user.git_user_name is None
|
||||
assert queried_user.git_user_email is None
|
||||
|
||||
|
||||
def test_user_model_git_user_fields_in_table_columns():
|
||||
"""Test that git_user_name and git_user_email are in User table columns."""
|
||||
# Arrange & Act
|
||||
column_names = [c.name for c in User.__table__.columns]
|
||||
|
||||
# Assert
|
||||
assert 'git_user_name' in column_names
|
||||
assert 'git_user_email' in column_names
|
||||
|
||||
|
||||
def test_user_model_git_user_fields_hasattr(session_maker):
|
||||
"""Test that hasattr returns True for git_user_* fields on User model.
|
||||
|
||||
This verifies the fix for SaasSettingsStore.store() which uses hasattr
|
||||
to determine if a field should be persisted to a model.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Arrange
|
||||
org = Org(name='test_org_hasattr')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid4(), current_org_id=org.id)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Assert - hasattr must return True for store() to work
|
||||
assert hasattr(user, 'git_user_name')
|
||||
assert hasattr(user, 'git_user_email')
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -415,374 +415,3 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm
|
||||
)
|
||||
assert persisted_member.max_iterations == 100
|
||||
assert persisted_member.llm_model == 'gpt-4'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_org_cascade_success(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Valid organization with associated data
|
||||
WHEN: delete_org_cascade is called
|
||||
THEN: Organization and all associated data are deleted and org object is returned
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
# Create expected return object
|
||||
expected_org = Org(
|
||||
id=org_id,
|
||||
name='Test Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
|
||||
# Mock delete_org_cascade to avoid database schema constraints
|
||||
async def mock_delete_org_cascade(org_id_param):
|
||||
# Verify the method was called with correct parameter
|
||||
assert org_id_param == org_id
|
||||
|
||||
# Return the organization object (simulating successful deletion)
|
||||
return expected_org
|
||||
|
||||
with patch(
|
||||
'storage.org_store.OrgStore.delete_org_cascade', mock_delete_org_cascade
|
||||
):
|
||||
# Act
|
||||
result = await OrgStore.delete_org_cascade(org_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.id == org_id
|
||||
assert result.name == 'Test Organization'
|
||||
assert result.contact_name == 'John Doe'
|
||||
assert result.contact_email == 'john@example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_org_cascade_not_found(session_maker):
|
||||
"""
|
||||
GIVEN: Organization ID that doesn't exist
|
||||
WHEN: delete_org_cascade is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = uuid.uuid4()
|
||||
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
# Act
|
||||
result = await OrgStore.delete_org_cascade(non_existent_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_org_cascade_litellm_failure_causes_rollback(
|
||||
session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Organization exists but LiteLLM cleanup fails
|
||||
WHEN: delete_org_cascade is called
|
||||
THEN: Transaction is rolled back and organization still exists
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
org = Org(
|
||||
id=org_id,
|
||||
name='Test Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=1,
|
||||
status='active',
|
||||
llm_api_key='test-key',
|
||||
)
|
||||
session.add_all([role, user, org, org_member])
|
||||
session.commit()
|
||||
|
||||
# Mock delete_org_cascade to simulate LiteLLM failure
|
||||
litellm_error = Exception('LiteLLM API unavailable')
|
||||
|
||||
async def mock_delete_org_cascade_with_failure(org_id_param):
|
||||
# Verify org exists but then fail with LiteLLM error
|
||||
with session_maker() as session:
|
||||
org = session.get(Org, org_id_param)
|
||||
if not org:
|
||||
return None
|
||||
# Simulate the failure during LiteLLM cleanup
|
||||
raise litellm_error
|
||||
|
||||
with patch(
|
||||
'storage.org_store.OrgStore.delete_org_cascade',
|
||||
mock_delete_org_cascade_with_failure,
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await OrgStore.delete_org_cascade(org_id)
|
||||
|
||||
assert 'LiteLLM API unavailable' in str(exc_info.value)
|
||||
|
||||
# Verify transaction was rolled back - organization should still exist
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, org_id)
|
||||
assert persisted_org is not None
|
||||
assert persisted_org.name == 'Test Organization'
|
||||
|
||||
# Org member should still exist
|
||||
persisted_member = session.query(OrgMember).filter_by(org_id=org_id).first()
|
||||
assert persisted_member is not None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User is member of multiple organizations
|
||||
WHEN: get_user_orgs_paginated is called without page_id
|
||||
THEN: First page of organizations is returned in alphabetical order
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
other_user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
# Create orgs for the user
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
org3 = Org(name='Gamma Org')
|
||||
# Create org for another user (should not be included)
|
||||
org4 = Org(name='Other Org')
|
||||
session.add_all([org1, org2, org3, org4])
|
||||
session.flush()
|
||||
|
||||
# Create user and role
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
other_user = User(id=other_user_id, current_org_id=org4.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, other_user, role])
|
||||
session.flush()
|
||||
|
||||
# Create memberships
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
other_member = OrgMember(
|
||||
org_id=org4.id, user_id=other_user_id, role_id=1, llm_api_key='key4'
|
||||
)
|
||||
session.add_all([member1, member2, member3, other_member])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=2
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 2
|
||||
assert orgs[0].name == 'Alpha Org'
|
||||
assert orgs[1].name == 'Beta Org'
|
||||
assert next_page_id == '2' # Has more results
|
||||
# Verify other user's org is not included
|
||||
org_names = [org.name for org in orgs]
|
||||
assert 'Other Org' not in org_names
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has multiple organizations and page_id is provided
|
||||
WHEN: get_user_orgs_paginated is called with page_id
|
||||
THEN: Organizations starting from offset are returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
org3 = Org(name='Gamma Org')
|
||||
session.add_all([org1, org2, org3])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
session.add_all([member1, member2, member3])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id='1', limit=1
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == 'Beta Org' # Second org (offset 1)
|
||||
assert next_page_id == '2' # Has more results
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has organizations but fewer than limit
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: All organizations are returned and next_page_id is None
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
session.add_all([member1, member2])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 2
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Invalid page_id (non-numeric string)
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Results start from beginning (offset 0)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
session.add(org1)
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
session.add(member1)
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id='invalid', limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == 'Alpha Org'
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
"""
|
||||
GIVEN: User has no organizations
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Empty list and None next_page_id are returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 0
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has organizations with different names
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Organizations are returned in alphabetical order by name
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
# Create orgs in non-alphabetical order
|
||||
org3 = Org(name='Zebra Org')
|
||||
org1 = Org(name='Apple Org')
|
||||
org2 = Org(name='Banana Org')
|
||||
session.add_all([org3, org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
session.add_all([member1, member2, member3])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, _ = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 3
|
||||
assert orgs[0].name == 'Apple Org'
|
||||
assert orgs[1].name == 'Banana Org'
|
||||
assert orgs[2].name == 'Zebra Org'
|
||||
|
||||
@@ -1,36 +1,9 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
# Mock the database module before importing RoleStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
def test_get_role_by_id(session_maker):
|
||||
@@ -108,63 +81,3 @@ def test_create_role(session_maker):
|
||||
assert role.name == 'moderator'
|
||||
assert role.rank == 2
|
||||
assert role.id is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_role_by_name_async_with_session(async_session_maker):
|
||||
"""Test getting role by name asynchronously with an explicit session."""
|
||||
# Create a test role
|
||||
async with async_session_maker() as session:
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
await session.refresh(role)
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval with explicit session
|
||||
async with async_session_maker() as session:
|
||||
retrieved_role = await RoleStore.get_role_by_name_async(
|
||||
'admin', session=session
|
||||
)
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
assert retrieved_role.rank == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_role_by_name_async_without_session(async_session_maker):
|
||||
"""Test getting role by name asynchronously using internal session maker."""
|
||||
# Create a test role
|
||||
async with async_session_maker() as session:
|
||||
role = Role(name='editor', rank=2)
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
await session.refresh(role)
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval without explicit session (using patched a_session_maker)
|
||||
with patch('storage.role_store.a_session_maker', async_session_maker):
|
||||
retrieved_role = await RoleStore.get_role_by_name_async('editor')
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'editor'
|
||||
assert retrieved_role.rank == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_role_by_name_async_not_found_with_session(async_session_maker):
|
||||
"""Test getting role by name when it doesn't exist (with explicit session)."""
|
||||
async with async_session_maker() as session:
|
||||
retrieved_role = await RoleStore.get_role_by_name_async(
|
||||
'nonexistent', session=session
|
||||
)
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_role_by_name_async_not_found_without_session(async_session_maker):
|
||||
"""Test getting role by name when it doesn't exist (without explicit session)."""
|
||||
with patch('storage.role_store.a_session_maker', async_session_maker):
|
||||
retrieved_role = await RoleStore.get_role_by_name_async('nonexistent')
|
||||
assert retrieved_role is None
|
||||
|
||||
+1
-262
@@ -2,7 +2,7 @@
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.shared_conversation_models import (
|
||||
@@ -13,9 +13,6 @@ from server.sharing.sql_shared_conversation_info_service import (
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.org import Org
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user import User
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
@@ -431,261 +428,3 @@ class TestSharedConversationInfoService:
|
||||
page1_ids = {item.id for item in result.items}
|
||||
page2_ids = {item.id for item in result2.items}
|
||||
assert page1_ids.isdisjoint(page2_ids)
|
||||
|
||||
|
||||
class TestSharedConversationInfoServiceWithSaasMetadata:
|
||||
"""Test cases for SharedConversationInfoService with SAAS metadata.
|
||||
|
||||
These tests verify that created_by_user_id is correctly retrieved from
|
||||
the conversation_metadata_saas table when it exists.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine_with_saas(self):
|
||||
"""Create an async SQLite engine with all SAAS tables."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_with_saas(
|
||||
self, async_engine_with_saas
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing with SAAS tables."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine_with_saas, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
@pytest.fixture
|
||||
async def test_org(self, async_session_with_saas) -> Org:
|
||||
"""Create a test organization."""
|
||||
org = Org(id=uuid4(), name=f'test_org_{uuid4().hex[:8]}')
|
||||
async_session_with_saas.add(org)
|
||||
await async_session_with_saas.commit()
|
||||
return org
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(self, async_session_with_saas, test_org) -> User:
|
||||
"""Create a test user belonging to the test organization."""
|
||||
user = User(id=uuid4(), current_org_id=test_org.id)
|
||||
async_session_with_saas.add(user)
|
||||
await async_session_with_saas.commit()
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
async def shared_service_with_saas(self, async_session_with_saas):
|
||||
"""Create a SharedConversationInfoService for testing."""
|
||||
return SQLSharedConversationInfoService(db_session=async_session_with_saas)
|
||||
|
||||
@pytest.fixture
|
||||
async def app_service_with_saas(self, async_session_with_saas):
|
||||
"""Create an AppConversationInfoService for creating test data."""
|
||||
return SQLAppConversationInfoService(
|
||||
db_session=async_session_with_saas,
|
||||
user_context=SpecifyUserContext(user_id=None),
|
||||
)
|
||||
|
||||
async def _create_saas_metadata(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
conversation_id: UUID,
|
||||
user_id: UUID,
|
||||
org_id: UUID,
|
||||
) -> StoredConversationMetadataSaas:
|
||||
"""Helper to create SAAS metadata for a conversation."""
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(conversation_id),
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
db_session.add(saas_metadata)
|
||||
await db_session.commit()
|
||||
return saas_metadata
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_returns_user_id_from_saas_metadata(
|
||||
self,
|
||||
shared_service_with_saas,
|
||||
app_service_with_saas,
|
||||
async_session_with_saas,
|
||||
test_user,
|
||||
test_org,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns created_by_user_id from SAAS metadata."""
|
||||
# Arrange
|
||||
conversation_id = uuid4()
|
||||
conversation = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='test_sandbox',
|
||||
title='Public Conversation With User',
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
await app_service_with_saas.save_app_conversation_info(conversation)
|
||||
await self._create_saas_metadata(
|
||||
async_session_with_saas, conversation_id, test_user.id, test_org.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await shared_service_with_saas.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.created_by_user_id == str(test_user.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversations_returns_user_id_from_saas_metadata(
|
||||
self,
|
||||
shared_service_with_saas,
|
||||
app_service_with_saas,
|
||||
async_session_with_saas,
|
||||
test_user,
|
||||
test_org,
|
||||
):
|
||||
"""Test that search_shared_conversation_info returns created_by_user_id from SAAS metadata."""
|
||||
# Arrange
|
||||
conversation_id = uuid4()
|
||||
conversation = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='test_sandbox_search',
|
||||
title='Searchable Public Conversation',
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
await app_service_with_saas.save_app_conversation_info(conversation)
|
||||
await self._create_saas_metadata(
|
||||
async_session_with_saas, conversation_id, test_user.id, test_org.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await shared_service_with_saas.search_shared_conversation_info()
|
||||
|
||||
# Assert
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].created_by_user_id == str(test_user.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_get_shared_conversations_returns_user_id_from_saas_metadata(
|
||||
self,
|
||||
shared_service_with_saas,
|
||||
app_service_with_saas,
|
||||
async_session_with_saas,
|
||||
test_user,
|
||||
test_org,
|
||||
):
|
||||
"""Test that batch_get_shared_conversation_info returns created_by_user_id from SAAS metadata."""
|
||||
# Arrange
|
||||
conversation_id = uuid4()
|
||||
conversation = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='test_sandbox_batch',
|
||||
title='Batch Get Conversation',
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
await app_service_with_saas.save_app_conversation_info(conversation)
|
||||
await self._create_saas_metadata(
|
||||
async_session_with_saas, conversation_id, test_user.id, test_org.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await shared_service_with_saas.batch_get_shared_conversation_info(
|
||||
[conversation_id]
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0] is not None
|
||||
assert result[0].created_by_user_id == str(test_user.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_conversations_with_and_without_saas_metadata(
|
||||
self,
|
||||
shared_service_with_saas,
|
||||
app_service_with_saas,
|
||||
async_session_with_saas,
|
||||
test_user,
|
||||
test_org,
|
||||
):
|
||||
"""Test handling of conversations where some have SAAS metadata and some don't."""
|
||||
# Arrange
|
||||
conv_with_saas_id = uuid4()
|
||||
conv_without_saas_id = uuid4()
|
||||
|
||||
conv_with_saas = AppConversationInfo(
|
||||
id=conv_with_saas_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='sandbox_with_saas',
|
||||
title='With SAAS Metadata',
|
||||
created_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conv_without_saas = AppConversationInfo(
|
||||
id=conv_without_saas_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='sandbox_without_saas',
|
||||
title='Without SAAS Metadata',
|
||||
created_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
await app_service_with_saas.save_app_conversation_info(conv_with_saas)
|
||||
await app_service_with_saas.save_app_conversation_info(conv_without_saas)
|
||||
await self._create_saas_metadata(
|
||||
async_session_with_saas, conv_with_saas_id, test_user.id, test_org.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await shared_service_with_saas.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.items) == 2
|
||||
conv_without = next(
|
||||
item for item in result.items if item.id == conv_without_saas_id
|
||||
)
|
||||
conv_with = next(item for item in result.items if item.id == conv_with_saas_id)
|
||||
assert conv_without.created_by_user_id is None
|
||||
assert conv_with.created_by_user_id == str(test_user.id)
|
||||
|
||||
@@ -149,18 +149,6 @@ def test_infer_repo_from_message():
|
||||
('https://github.com/My-User/My-Repo.git', ['My-User/My-Repo']),
|
||||
('Check the my.user/my.repo repository', ['my.user/my.repo']),
|
||||
('repos: user_1/repo-1 and user.2/repo_2', ['user_1/repo-1', 'user.2/repo_2']),
|
||||
# Backtick-wrapped repo mentions (common in Slack/Discord messages)
|
||||
(
|
||||
'@openhands-exp just echo hello world in `OpenHands/OpenHands-CLI` repository',
|
||||
['OpenHands/OpenHands-CLI'],
|
||||
),
|
||||
(
|
||||
'@openhands-exp echo hello world with {{OpenHands/OpenHands-CLI}}',
|
||||
['OpenHands/OpenHands-CLI'],
|
||||
),
|
||||
('Deploy the `test/project` repo', ['test/project']),
|
||||
# Colon-wrapped repo mentions
|
||||
('Check the :owner/repo: here', ['owner/repo']),
|
||||
# Large number of repositories
|
||||
('Repos: a/b, c/d, e/f, g/h, i/j', ['a/b', 'c/d', 'e/f', 'g/h', 'i/j']),
|
||||
# Mixed with false positives that should be filtered
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { PlanPreview } from "#/components/features/chat/plan-preview";
|
||||
|
||||
// Mock the feature flag to always return true (not testing feature flag behavior)
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
USE_PLANNING_AGENT: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
// Mock i18n - need to preserve initReactI18next and I18nextProvider for test-utils
|
||||
vi.mock("react-i18next", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("react-i18next")>();
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
describe("PlanPreview", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render nothing when planContent is null", () => {
|
||||
renderWithProviders(<PlanPreview planContent={null} />);
|
||||
|
||||
const contentDiv = screen.getByTestId("plan-preview-content");
|
||||
expect(contentDiv).toBeInTheDocument();
|
||||
expect(contentDiv.textContent?.trim() || "").toBe("");
|
||||
});
|
||||
|
||||
it("should render nothing when planContent is undefined", () => {
|
||||
renderWithProviders(<PlanPreview planContent={undefined} />);
|
||||
|
||||
const contentDiv = screen.getByTestId("plan-preview-content");
|
||||
expect(contentDiv).toBeInTheDocument();
|
||||
expect(contentDiv.textContent?.trim() || "").toBe("");
|
||||
});
|
||||
|
||||
it("should render markdown content when planContent is provided", () => {
|
||||
const planContent = "# Plan Title\n\nThis is the plan content.";
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent={planContent} />,
|
||||
);
|
||||
|
||||
// Check that component rendered and contains the content (markdown may break up text)
|
||||
expect(container.firstChild).not.toBeNull();
|
||||
expect(container.textContent).toContain("Plan Title");
|
||||
expect(container.textContent).toContain("This is the plan content.");
|
||||
});
|
||||
|
||||
it("should render full content when length is less than or equal to 300 characters", () => {
|
||||
const planContent = "A".repeat(300);
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent={planContent} />,
|
||||
);
|
||||
|
||||
// Content should be present (may be broken up by markdown)
|
||||
expect(container.textContent).toContain(planContent);
|
||||
expect(screen.queryByText(/COMMON\$READ_MORE/i)).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should truncate content when length exceeds 300 characters", () => {
|
||||
const longContent = "A".repeat(350);
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent={longContent} />,
|
||||
);
|
||||
|
||||
// Truncated content should be present (may be broken up by markdown)
|
||||
expect(container.textContent).toContain("A".repeat(300));
|
||||
expect(container.textContent).toContain("...");
|
||||
expect(container.textContent).toContain("COMMON$READ_MORE");
|
||||
});
|
||||
|
||||
it("should call onViewClick when View button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onViewClick = vi.fn();
|
||||
|
||||
renderWithProviders(
|
||||
<PlanPreview planContent="Plan content" onViewClick={onViewClick} />,
|
||||
);
|
||||
|
||||
const viewButton = screen.getByTestId("plan-preview-view-button");
|
||||
expect(viewButton).toBeInTheDocument();
|
||||
|
||||
await user.click(viewButton);
|
||||
|
||||
expect(onViewClick).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should call onViewClick when Read More button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onViewClick = vi.fn();
|
||||
const longContent = "A".repeat(350);
|
||||
|
||||
renderWithProviders(
|
||||
<PlanPreview planContent={longContent} onViewClick={onViewClick} />,
|
||||
);
|
||||
|
||||
const readMoreButton = screen.getByTestId("plan-preview-read-more-button");
|
||||
expect(readMoreButton).toBeInTheDocument();
|
||||
|
||||
await user.click(readMoreButton);
|
||||
|
||||
expect(onViewClick).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should call onBuildClick when Build button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onBuildClick = vi.fn();
|
||||
|
||||
renderWithProviders(
|
||||
<PlanPreview planContent="Plan content" onBuildClick={onBuildClick} />,
|
||||
);
|
||||
|
||||
const buildButton = screen.getByTestId("plan-preview-build-button");
|
||||
expect(buildButton).toBeInTheDocument();
|
||||
|
||||
await user.click(buildButton);
|
||||
|
||||
expect(onBuildClick).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should render header with PLAN_MD text", () => {
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent="Plan content" />,
|
||||
);
|
||||
|
||||
// Check that the translation key is rendered (i18n mock returns the key)
|
||||
expect(container.textContent).toContain("COMMON$PLAN_MD");
|
||||
});
|
||||
|
||||
it("should render plan content", () => {
|
||||
const planContent = `# Heading 1
|
||||
## Heading 2
|
||||
- List item 1
|
||||
- List item 2
|
||||
|
||||
**Bold text** and *italic text*`;
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent={planContent} />,
|
||||
);
|
||||
|
||||
expect(container.textContent).toContain("Heading 1");
|
||||
expect(container.textContent).toContain("Heading 2");
|
||||
});
|
||||
});
|
||||
@@ -1,186 +0,0 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, vi, beforeEach, it } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { GitBranchDropdown } from "../../../../src/components/features/home/git-branch-dropdown/git-branch-dropdown";
|
||||
import { Branch } from "#/types/git";
|
||||
|
||||
// Mock the branch data hook
|
||||
const mockUseBranchData = vi.fn();
|
||||
vi.mock("#/hooks/query/use-branch-data", () => ({
|
||||
useBranchData: (...args: unknown[]) => mockUseBranchData(...args),
|
||||
}));
|
||||
|
||||
const MOCK_BRANCHES: Branch[] = [
|
||||
{ name: "main", commit_sha: "abc123", protected: true },
|
||||
{ name: "develop", commit_sha: "def456", protected: false },
|
||||
{ name: "feature/test", commit_sha: "ghi789", protected: false },
|
||||
];
|
||||
|
||||
const mockOnBranchSelect = vi.fn();
|
||||
|
||||
const renderDropdown = (
|
||||
props: Partial<Parameters<typeof GitBranchDropdown>[0]> = {},
|
||||
) => {
|
||||
// Default mock return value
|
||||
mockUseBranchData.mockReturnValue({
|
||||
branches: MOCK_BRANCHES,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
isFetchingNextPage: false,
|
||||
isSearchLoading: false,
|
||||
});
|
||||
|
||||
return render(
|
||||
<GitBranchDropdown
|
||||
repository="user/repo"
|
||||
provider="github"
|
||||
selectedBranch={null}
|
||||
onBranchSelect={mockOnBranchSelect}
|
||||
// eslint-disable-next-line react/jsx-props-no-spreading
|
||||
{...props}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
describe("GitBranchDropdown", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("dropdown behavior", () => {
|
||||
it("should open dropdown when input is clicked", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-branch-dropdown-input");
|
||||
await userEvent.click(input);
|
||||
|
||||
// Dropdown should be open (menu should be visible)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-branch-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should keep dropdown open when clicking input while already open", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-branch-dropdown-input");
|
||||
|
||||
// First click - open dropdown
|
||||
await userEvent.click(input);
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-branch-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Second click on input - should stay open (not close)
|
||||
await userEvent.click(input);
|
||||
|
||||
// Dropdown should still be open
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-branch-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should preserve typed text when clicking input while typing", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId(
|
||||
"git-branch-dropdown-input",
|
||||
) as HTMLInputElement;
|
||||
|
||||
// Click to open and type
|
||||
await userEvent.click(input);
|
||||
await userEvent.type(input, "feat");
|
||||
|
||||
expect(input.value).toBe("feat");
|
||||
|
||||
// Click on input again (should not reset text)
|
||||
await userEvent.click(input);
|
||||
|
||||
// Text should be preserved
|
||||
expect(input.value).toBe("feat");
|
||||
});
|
||||
});
|
||||
|
||||
describe("cursor position preservation", () => {
|
||||
it("should allow editing in the middle of input text", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId(
|
||||
"git-branch-dropdown-input",
|
||||
) as HTMLInputElement;
|
||||
|
||||
// Click and type initial text
|
||||
await userEvent.click(input);
|
||||
await userEvent.type(input, "hello");
|
||||
|
||||
expect(input.value).toBe("hello");
|
||||
|
||||
// Move cursor to position 2 and type
|
||||
input.setSelectionRange(2, 2);
|
||||
await userEvent.type(input, "X");
|
||||
|
||||
// The character should be inserted (exact position may vary based on browser behavior)
|
||||
expect(input.value).toContain("X");
|
||||
});
|
||||
});
|
||||
|
||||
describe("input synchronization", () => {
|
||||
it("should show selected branch name in input when provided", async () => {
|
||||
const selectedBranch = MOCK_BRANCHES[0];
|
||||
renderDropdown({ selectedBranch });
|
||||
|
||||
const input = screen.getByTestId(
|
||||
"git-branch-dropdown-input",
|
||||
) as HTMLInputElement;
|
||||
|
||||
await waitFor(() => {
|
||||
expect(input.value).toBe(selectedBranch.name);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("branch selection", () => {
|
||||
it("should call onBranchSelect when a branch is selected", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-branch-dropdown-input");
|
||||
await userEvent.click(input);
|
||||
|
||||
// Wait for dropdown to open and show branches
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("main")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click on a branch
|
||||
await userEvent.click(screen.getByText("develop"));
|
||||
|
||||
expect(mockOnBranchSelect).toHaveBeenCalledWith(MOCK_BRANCHES[1]);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,234 +0,0 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, vi, beforeEach, it } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { GitRepoDropdown } from "../../../../src/components/features/home/git-repo-dropdown/git-repo-dropdown";
|
||||
import { GitRepository } from "#/types/git";
|
||||
|
||||
// Mock the repository data hook
|
||||
const mockUseRepositoryData = vi.fn();
|
||||
vi.mock(
|
||||
"#/components/features/home/git-repo-dropdown/use-repository-data",
|
||||
() => ({
|
||||
useRepositoryData: (...args: unknown[]) => mockUseRepositoryData(...args),
|
||||
}),
|
||||
);
|
||||
|
||||
// Mock the URL search hook
|
||||
const mockUseUrlSearch = vi.fn();
|
||||
vi.mock("#/components/features/home/git-repo-dropdown/use-url-search", () => ({
|
||||
useUrlSearch: (...args: unknown[]) => mockUseUrlSearch(...args),
|
||||
}));
|
||||
|
||||
// Mock useConfig
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({ data: null }),
|
||||
}));
|
||||
|
||||
// Mock useHomeStore
|
||||
vi.mock("#/stores/home-store", () => ({
|
||||
useHomeStore: () => ({ recentRepositories: [] }),
|
||||
}));
|
||||
|
||||
const MOCK_REPOSITORIES: GitRepository[] = [
|
||||
{
|
||||
id: "1",
|
||||
full_name: "user/repo-one",
|
||||
git_provider: "github",
|
||||
is_public: true,
|
||||
},
|
||||
{
|
||||
id: "2",
|
||||
full_name: "user/repo-two",
|
||||
git_provider: "github",
|
||||
is_public: true,
|
||||
},
|
||||
{
|
||||
id: "3",
|
||||
full_name: "org/feature-repo",
|
||||
git_provider: "github",
|
||||
is_public: false,
|
||||
},
|
||||
];
|
||||
|
||||
const mockOnChange = vi.fn();
|
||||
|
||||
const setupDefaultMocks = (
|
||||
repositoryDataOverrides: Partial<
|
||||
ReturnType<typeof mockUseRepositoryData>
|
||||
> = {},
|
||||
) => {
|
||||
mockUseRepositoryData.mockReturnValue({
|
||||
repositories: MOCK_REPOSITORIES,
|
||||
selectedRepository: null,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
isFetchingNextPage: false,
|
||||
isSearchLoading: false,
|
||||
...repositoryDataOverrides,
|
||||
});
|
||||
|
||||
mockUseUrlSearch.mockReturnValue({
|
||||
urlSearchResults: [],
|
||||
isUrlSearchLoading: false,
|
||||
});
|
||||
};
|
||||
|
||||
const renderDropdown = (
|
||||
props: Partial<Parameters<typeof GitRepoDropdown>[0]> = {},
|
||||
repositoryDataOverrides: Partial<
|
||||
ReturnType<typeof mockUseRepositoryData>
|
||||
> = {},
|
||||
) => {
|
||||
// Set up mocks with optional overrides
|
||||
setupDefaultMocks(repositoryDataOverrides);
|
||||
|
||||
return render(
|
||||
<GitRepoDropdown
|
||||
provider="github"
|
||||
onChange={mockOnChange}
|
||||
// eslint-disable-next-line react/jsx-props-no-spreading
|
||||
{...props}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
describe("GitRepoDropdown", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("dropdown behavior", () => {
|
||||
it("should open dropdown when input is clicked", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown");
|
||||
await userEvent.click(input);
|
||||
|
||||
// Dropdown should be open (menu should be visible)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-repo-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should keep dropdown open when clicking input while already open", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown");
|
||||
|
||||
// First click - open dropdown
|
||||
await userEvent.click(input);
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-repo-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Second click on input - should stay open (not close)
|
||||
await userEvent.click(input);
|
||||
|
||||
// Dropdown should still be open
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("git-repo-dropdown-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should preserve typed text when clicking input while typing", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
|
||||
|
||||
// Click to open and type
|
||||
await userEvent.click(input);
|
||||
await userEvent.type(input, "repo");
|
||||
|
||||
expect(input.value).toBe("repo");
|
||||
|
||||
// Click on input again (should not reset text)
|
||||
await userEvent.click(input);
|
||||
|
||||
// Text should be preserved
|
||||
expect(input.value).toBe("repo");
|
||||
});
|
||||
});
|
||||
|
||||
describe("cursor position preservation", () => {
|
||||
it("should allow editing in the middle of input text", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
|
||||
|
||||
// Click and type initial text
|
||||
await userEvent.click(input);
|
||||
await userEvent.type(input, "hello");
|
||||
|
||||
expect(input.value).toBe("hello");
|
||||
|
||||
// Move cursor to position 2 and type
|
||||
input.setSelectionRange(2, 2);
|
||||
await userEvent.type(input, "X");
|
||||
|
||||
// The character should be inserted (exact position may vary based on browser behavior)
|
||||
expect(input.value).toContain("X");
|
||||
});
|
||||
});
|
||||
|
||||
describe("input synchronization", () => {
|
||||
it("should show selected repository name in input when provided", async () => {
|
||||
const selectedRepository = MOCK_REPOSITORIES[0];
|
||||
|
||||
renderDropdown(
|
||||
{ value: selectedRepository.full_name },
|
||||
{ selectedRepository },
|
||||
);
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
|
||||
|
||||
await waitFor(() => {
|
||||
expect(input.value).toBe(selectedRepository.full_name);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("repository selection", () => {
|
||||
it("should call onChange when a repository is selected", async () => {
|
||||
renderDropdown();
|
||||
|
||||
const input = screen.getByTestId("git-repo-dropdown");
|
||||
await userEvent.click(input);
|
||||
|
||||
// Wait for dropdown to open and show repositories
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("user/repo-one")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click on a repository
|
||||
await userEvent.click(screen.getByText("user/repo-two"));
|
||||
|
||||
expect(mockOnChange).toHaveBeenCalledWith(MOCK_REPOSITORIES[1]);
|
||||
});
|
||||
});
|
||||
});
|
||||
-35
@@ -1,35 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { shouldRenderEvent } from "#/components/v1/chat/event-content-helpers/should-render-event";
|
||||
import {
|
||||
createPlanningFileEditorActionEvent,
|
||||
createOtherActionEvent,
|
||||
createPlanningObservationEvent,
|
||||
createUserMessageEvent,
|
||||
} from "test-utils";
|
||||
|
||||
describe("shouldRenderEvent - PlanningFileEditorAction", () => {
|
||||
it("should return false for PlanningFileEditorAction", () => {
|
||||
const event = createPlanningFileEditorActionEvent("action-1");
|
||||
|
||||
expect(shouldRenderEvent(event)).toBe(false);
|
||||
});
|
||||
|
||||
it("should return true for other action types", () => {
|
||||
const event = createOtherActionEvent("action-1");
|
||||
|
||||
expect(shouldRenderEvent(event)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true for PlanningFileEditorObservation", () => {
|
||||
const event = createPlanningObservationEvent("obs-1");
|
||||
|
||||
// Observations should still render (they're handled separately in event-message)
|
||||
expect(shouldRenderEvent(event)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true for user message events", () => {
|
||||
const event = createUserMessageEvent("msg-1");
|
||||
|
||||
expect(shouldRenderEvent(event)).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -1,159 +0,0 @@
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { screen, render } from "@testing-library/react";
|
||||
import { EventMessage } from "#/components/v1/chat/event-message";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
import {
|
||||
renderWithProviders,
|
||||
createPlanningObservationEvent,
|
||||
} from "test-utils";
|
||||
|
||||
// Mock the feature flag
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
USE_PLANNING_AGENT: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
// Mock useConfig
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({
|
||||
data: { APP_MODE: "saas" },
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock PlanPreview component to verify it's rendered
|
||||
vi.mock("#/components/features/chat/plan-preview", () => ({
|
||||
PlanPreview: ({ planContent }: { planContent?: string | null }) => (
|
||||
<div data-testid="plan-preview">Plan Preview: {planContent || "null"}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
describe("EventMessage - PlanPreview rendering", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
// Reset conversation store
|
||||
useConversationStore.setState({
|
||||
planContent: null,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render PlanPreview when PlanningFileEditorObservation event ID is in planPreviewEventIds", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set(["plan-obs-1"]);
|
||||
const planContent = "This is the plan content";
|
||||
|
||||
useConversationStore.setState({ planContent });
|
||||
|
||||
renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("plan-preview")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(`Plan Preview: ${planContent}`),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should return null when PlanningFileEditorObservation event ID is NOT in planPreviewEventIds", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set(["plan-obs-2"]); // Different ID
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null when planPreviewEventIds is undefined", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={undefined}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it("should use planContent from conversation store", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set(["plan-obs-1"]);
|
||||
const planContent = "Store plan content";
|
||||
|
||||
useConversationStore.setState({ planContent });
|
||||
|
||||
renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByText(`Plan Preview: ${planContent}`),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should handle null planContent from store", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set(["plan-obs-1"]);
|
||||
|
||||
useConversationStore.setState({ planContent: null });
|
||||
|
||||
renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("plan-preview")).toBeInTheDocument();
|
||||
expect(screen.getByText("Plan Preview: null")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should handle empty planPreviewEventIds set", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set<string>();
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -1,195 +0,0 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
usePlanPreviewEvents,
|
||||
shouldShowPlanPreview,
|
||||
} from "#/components/v1/chat/hooks/use-plan-preview-events";
|
||||
import {
|
||||
OpenHandsEvent,
|
||||
MessageEvent,
|
||||
ObservationEvent,
|
||||
PlanningFileEditorObservation,
|
||||
} from "#/types/v1/core";
|
||||
|
||||
// Helper to create a user message event
|
||||
const createUserMessageEvent = (id: string): MessageEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "user",
|
||||
llm_message: {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "User message" }],
|
||||
},
|
||||
activated_microagents: [],
|
||||
extended_content: [],
|
||||
});
|
||||
|
||||
// Helper to create a PlanningFileEditorObservation event
|
||||
const createPlanningObservationEvent = (
|
||||
id: string,
|
||||
actionId: string = "action-1",
|
||||
): ObservationEvent<PlanningFileEditorObservation> => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "environment",
|
||||
tool_name: "planning_file_editor",
|
||||
tool_call_id: "call-1",
|
||||
action_id: actionId,
|
||||
observation: {
|
||||
kind: "PlanningFileEditorObservation",
|
||||
content: [{ type: "text", text: "Plan content" }],
|
||||
is_error: false,
|
||||
command: "create",
|
||||
path: "/workspace/PLAN.md",
|
||||
prev_exist: false,
|
||||
old_content: null,
|
||||
new_content: "Plan content",
|
||||
},
|
||||
});
|
||||
|
||||
// Helper to create a non-planning observation event
|
||||
const createOtherObservationEvent = (id: string): ObservationEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "environment",
|
||||
tool_name: "execute_bash",
|
||||
tool_call_id: "call-1",
|
||||
action_id: "action-1",
|
||||
observation: {
|
||||
kind: "ExecuteBashObservation",
|
||||
content: [{ type: "text", text: "output" }],
|
||||
command: "echo test",
|
||||
exit_code: 0,
|
||||
error: false,
|
||||
timeout: false,
|
||||
metadata: {
|
||||
exit_code: 0,
|
||||
pid: 12345,
|
||||
username: "user",
|
||||
hostname: "localhost",
|
||||
working_dir: "/home/user",
|
||||
py_interpreter_path: null,
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
describe("usePlanPreviewEvents", () => {
|
||||
it("should return empty set when no events provided", () => {
|
||||
const { result } = renderHook(() => usePlanPreviewEvents([]));
|
||||
|
||||
expect(result.current).toBeInstanceOf(Set);
|
||||
expect(result.current.size).toBe(0);
|
||||
});
|
||||
|
||||
it("should return empty set when no PlanningFileEditorObservation events exist", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createOtherObservationEvent("obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
expect(result.current.size).toBe(0);
|
||||
});
|
||||
|
||||
it("should return event ID for single PlanningFileEditorObservation in one phase", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
expect(result.current.size).toBe(1);
|
||||
expect(result.current.has("plan-obs-1")).toBe(true);
|
||||
});
|
||||
|
||||
it("should return only the last PlanningFileEditorObservation when multiple exist in one phase", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
createPlanningObservationEvent("plan-obs-2"),
|
||||
createPlanningObservationEvent("plan-obs-3"),
|
||||
createOtherObservationEvent("other-obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
// Should only include the last one in the phase
|
||||
expect(result.current.size).toBe(1);
|
||||
expect(result.current.has("plan-obs-1")).toBe(false);
|
||||
expect(result.current.has("plan-obs-2")).toBe(false);
|
||||
expect(result.current.has("plan-obs-3")).toBe(true);
|
||||
});
|
||||
|
||||
it("should return one event ID per phase when multiple phases exist", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
createPlanningObservationEvent("plan-obs-2"),
|
||||
createUserMessageEvent("user-2"),
|
||||
createPlanningObservationEvent("plan-obs-3"),
|
||||
createPlanningObservationEvent("plan-obs-4"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
// Should have one preview per phase (last observation in each phase)
|
||||
expect(result.current.size).toBe(2);
|
||||
expect(result.current.has("plan-obs-2")).toBe(true); // Last in phase 1
|
||||
expect(result.current.has("plan-obs-4")).toBe(true); // Last in phase 2
|
||||
expect(result.current.has("plan-obs-1")).toBe(false);
|
||||
expect(result.current.has("plan-obs-3")).toBe(false);
|
||||
});
|
||||
|
||||
it("should handle phase with no PlanningFileEditorObservation", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createOtherObservationEvent("obs-1"),
|
||||
createUserMessageEvent("user-2"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
// Only phase 2 has a planning observation
|
||||
expect(result.current.size).toBe(1);
|
||||
expect(result.current.has("plan-obs-1")).toBe(true);
|
||||
});
|
||||
|
||||
it("should handle events starting with non-user message", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createOtherObservationEvent("obs-1"),
|
||||
createUserMessageEvent("user-1"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
// Events before first user message should be in first phase
|
||||
expect(result.current.size).toBe(1);
|
||||
expect(result.current.has("plan-obs-1")).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("shouldShowPlanPreview", () => {
|
||||
it("should return true when event ID is in the set", () => {
|
||||
const planPreviewEventIds = new Set(["event-1", "event-2", "event-3"]);
|
||||
|
||||
expect(shouldShowPlanPreview("event-2", planPreviewEventIds)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return false when event ID is not in the set", () => {
|
||||
const planPreviewEventIds = new Set(["event-1", "event-2"]);
|
||||
|
||||
expect(shouldShowPlanPreview("event-3", planPreviewEventIds)).toBe(false);
|
||||
});
|
||||
|
||||
it("should return false when set is empty", () => {
|
||||
const planPreviewEventIds = new Set<string>();
|
||||
|
||||
expect(shouldShowPlanPreview("event-1", planPreviewEventIds)).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -40,18 +40,6 @@ import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup";
|
||||
import { useEventStore } from "#/stores/use-event-store";
|
||||
import { isV1Event } from "#/types/v1/type-guards";
|
||||
|
||||
// Mock useUserConversation to return V1 conversation data
|
||||
vi.mock("#/hooks/query/use-user-conversation", () => ({
|
||||
useUserConversation: vi.fn(() => ({
|
||||
data: {
|
||||
conversation_version: "V1",
|
||||
status: "RUNNING",
|
||||
},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
})),
|
||||
}));
|
||||
|
||||
// MSW WebSocket mock setup
|
||||
const { wsLink, server: mswServer } = conversationWebSocketTestSetup();
|
||||
|
||||
@@ -679,16 +667,6 @@ describe("Conversation WebSocket Handler", () => {
|
||||
|
||||
// Set up MSW to mock both the HTTP API and WebSocket connection
|
||||
mswServer.use(
|
||||
// Mock events search for history preloading
|
||||
http.get(
|
||||
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
|
||||
async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
return HttpResponse.json({
|
||||
items: mockHistoryEvents,
|
||||
});
|
||||
},
|
||||
),
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(expectedEventCount),
|
||||
@@ -725,6 +703,11 @@ describe("Conversation WebSocket Handler", () => {
|
||||
`http://localhost:3000/api/conversations/${conversationId}`,
|
||||
);
|
||||
|
||||
// Initially should be loading history
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
|
||||
"true",
|
||||
);
|
||||
|
||||
// Wait for all events to be received
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("events-received")).toHaveTextContent("3");
|
||||
@@ -743,14 +726,6 @@ describe("Conversation WebSocket Handler", () => {
|
||||
|
||||
// Set up MSW to mock both the HTTP API and WebSocket connection
|
||||
mswServer.use(
|
||||
// Mock empty events search
|
||||
http.get(
|
||||
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
|
||||
() =>
|
||||
HttpResponse.json({
|
||||
items: [],
|
||||
}),
|
||||
),
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(0),
|
||||
@@ -800,16 +775,6 @@ describe("Conversation WebSocket Handler", () => {
|
||||
|
||||
// Set up MSW to mock both the HTTP API and WebSocket connection
|
||||
mswServer.use(
|
||||
// Mock events search for history preloading (50 events)
|
||||
http.get(
|
||||
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
|
||||
async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
return HttpResponse.json({
|
||||
items: mockHistoryEvents,
|
||||
});
|
||||
},
|
||||
),
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(expectedEventCount),
|
||||
@@ -845,6 +810,11 @@ describe("Conversation WebSocket Handler", () => {
|
||||
`http://localhost:3000/api/conversations/${conversationId}`,
|
||||
);
|
||||
|
||||
// Initially should be loading history
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
|
||||
"true",
|
||||
);
|
||||
|
||||
// Wait for all events to be received
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("events-received")).toHaveTextContent("50");
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import { describe, it, expect, afterEach, vi } from "vitest";
|
||||
import React from "react";
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
|
||||
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
|
||||
import EventService from "#/api/event-service/event-service.api";
|
||||
import { useUserConversation } from "#/hooks/query/use-user-conversation";
|
||||
import type { Conversation } from "#/api/open-hands.types";
|
||||
import type { OpenHandsEvent } from "#/types/v1/core";
|
||||
|
||||
function makeConversation(version: "V0" | "V1"): Conversation {
|
||||
return {
|
||||
conversation_id: "conv-test",
|
||||
title: "Test Conversation",
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
last_updated_at: new Date().toISOString(),
|
||||
created_at: new Date().toISOString(),
|
||||
status: "RUNNING",
|
||||
runtime_status: null,
|
||||
url: null,
|
||||
session_api_key: null,
|
||||
conversation_version: version,
|
||||
};
|
||||
}
|
||||
|
||||
function makeEvent(): OpenHandsEvent {
|
||||
return {
|
||||
id: "evt-1",
|
||||
} as OpenHandsEvent;
|
||||
}
|
||||
|
||||
// --------------------
|
||||
// Mocks
|
||||
// --------------------
|
||||
vi.mock("#/api/open-hands-axios", () => ({
|
||||
openHands: {
|
||||
get: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/api/event-service/event-service.api");
|
||||
vi.mock("#/hooks/query/use-user-conversation");
|
||||
|
||||
const queryClient = new QueryClient();
|
||||
|
||||
function wrapper({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
}
|
||||
|
||||
// --------------------
|
||||
// Tests
|
||||
// --------------------
|
||||
describe("useConversationHistory", () => {
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("calls V1 REST endpoint for V1 conversations", async () => {
|
||||
const v1SearchEventsSpy = vi.spyOn(EventService, "searchEventsV1");
|
||||
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: makeConversation("V1"),
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
v1SearchEventsSpy.mockResolvedValue([makeEvent()]);
|
||||
|
||||
const { result } = renderHook(() => useConversationHistory("conv-123"), {
|
||||
wrapper,
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data).toBeDefined();
|
||||
});
|
||||
|
||||
expect(EventService.searchEventsV1).toHaveBeenCalledWith("conv-123");
|
||||
expect(EventService.searchEventsV0).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("calls V0 REST endpoint for V0 conversations", async () => {
|
||||
const v0SearchEventsSpy = vi.spyOn(EventService, "searchEventsV0");
|
||||
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: makeConversation("V0"),
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
v0SearchEventsSpy.mockResolvedValue([makeEvent()]);
|
||||
|
||||
const { result } = renderHook(() => useConversationHistory("conv-456"), {
|
||||
wrapper,
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data).toBeDefined();
|
||||
});
|
||||
|
||||
expect(EventService.searchEventsV0).toHaveBeenCalledWith("conv-456");
|
||||
expect(EventService.searchEventsV1).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -7,19 +7,14 @@ import LoginPage from "#/routes/login";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
|
||||
const { useEmailVerificationMock, resendEmailVerificationMock } = vi.hoisted(
|
||||
() => ({
|
||||
useEmailVerificationMock: vi.fn(() => ({
|
||||
emailVerified: false,
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
userId: null as string | null,
|
||||
resendEmailVerification: vi.fn(),
|
||||
})),
|
||||
resendEmailVerificationMock: vi.fn(),
|
||||
}),
|
||||
);
|
||||
const { useEmailVerificationMock } = vi.hoisted(() => ({
|
||||
useEmailVerificationMock: vi.fn(() => ({
|
||||
emailVerified: false,
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-github-auth-url", () => ({
|
||||
useGitHubAuthUrl: () => "https://github.com/login/oauth/authorize",
|
||||
@@ -353,8 +348,6 @@ describe("LoginPage", () => {
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
userId: null,
|
||||
resendEmailVerification: resendEmailVerificationMock,
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
@@ -374,8 +367,6 @@ describe("LoginPage", () => {
|
||||
hasDuplicatedEmail: true,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
userId: null,
|
||||
resendEmailVerification: resendEmailVerificationMock,
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
@@ -388,41 +379,6 @@ describe("LoginPage", () => {
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should pass userId to EmailVerificationModal when userId is provided", async () => {
|
||||
const user = userEvent.setup();
|
||||
const testUserId = "test-user-id-123";
|
||||
const setEmailVerificationModalOpen = vi.fn();
|
||||
|
||||
useEmailVerificationMock.mockReturnValue({
|
||||
emailVerified: false,
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: true,
|
||||
setEmailVerificationModalOpen,
|
||||
userId: testUserId,
|
||||
resendEmailVerification: resendEmailVerificationMock,
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const resendButton = screen.getByRole("button", {
|
||||
name: /SETTINGS\$RESEND_VERIFICATION/i,
|
||||
});
|
||||
await user.click(resendButton);
|
||||
|
||||
expect(resendEmailVerificationMock).toHaveBeenCalledWith({
|
||||
userId: testUserId,
|
||||
isAuthFlow: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Loading States", () => {
|
||||
@@ -459,15 +415,6 @@ describe("LoginPage", () => {
|
||||
|
||||
describe("Terms and Privacy", () => {
|
||||
it("should display Terms and Privacy notice", async () => {
|
||||
useEmailVerificationMock.mockReturnValue({
|
||||
emailVerified: false,
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
userId: null as string | null,
|
||||
resendEmailVerification: resendEmailVerificationMock,
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
@@ -48,7 +48,6 @@ function LoginStub() {
|
||||
searchParams.get("email_verification_required") === "true";
|
||||
const emailVerified = searchParams.get("email_verified") === "true";
|
||||
const emailVerificationText = "AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY";
|
||||
const returnTo = searchParams.get("returnTo");
|
||||
|
||||
return (
|
||||
<div data-testid="login-page">
|
||||
@@ -59,7 +58,6 @@ function LoginStub() {
|
||||
{emailVerificationText}
|
||||
</div>
|
||||
)}
|
||||
{returnTo && <div data-testid="return-to-param">{returnTo}</div>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@@ -102,27 +100,6 @@ const RouterStubWithLogin = createRoutesStub([
|
||||
},
|
||||
]);
|
||||
|
||||
const RouterStubWithDeviceVerify = createRoutesStub([
|
||||
{
|
||||
Component: MainApp,
|
||||
path: "/",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="outlet-content" />,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="device-verify-page" />,
|
||||
path: "/oauth/device/verify",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
Component: LoginStub,
|
||||
path: "/login",
|
||||
},
|
||||
]);
|
||||
|
||||
const renderMainApp = (initialEntries: string[] = ["/"]) =>
|
||||
render(<RouterStub initialEntries={initialEntries} />, {
|
||||
wrapper: ({ children }) => (
|
||||
@@ -334,23 +311,5 @@ describe("MainApp", () => {
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
});
|
||||
|
||||
it("should preserve query parameters in returnTo when redirecting to login", async () => {
|
||||
renderWithLoginStub(RouterStubWithDeviceVerify, [
|
||||
"/oauth/device/verify?user_code=F9XN6BKU",
|
||||
]);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByTestId("login-page")).toBeInTheDocument();
|
||||
const returnToElement = screen.getByTestId("return-to-param");
|
||||
expect(returnToElement).toBeInTheDocument();
|
||||
expect(returnToElement.textContent).toBe(
|
||||
"/oauth/device/verify?user_code=F9XN6BKU",
|
||||
);
|
||||
},
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -138,72 +138,4 @@ describe("handleEventForUI", () => {
|
||||
anotherActionEvent,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should NOT replace ThinkAction with ThinkObservation", () => {
|
||||
const mockThinkAction: ActionEvent = {
|
||||
id: "test-think-action-1",
|
||||
timestamp: Date.now().toString(),
|
||||
source: "agent",
|
||||
thought: [{ type: "text", text: "I am thinking..." }],
|
||||
thinking_blocks: [],
|
||||
action: {
|
||||
kind: "ThinkAction",
|
||||
thought: "I am thinking...",
|
||||
},
|
||||
tool_name: "think",
|
||||
tool_call_id: "call_think_1",
|
||||
tool_call: {
|
||||
id: "call_think_1",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "think",
|
||||
arguments: "",
|
||||
},
|
||||
},
|
||||
llm_response_id: "response_think",
|
||||
security_risk: SecurityRisk.UNKNOWN,
|
||||
};
|
||||
|
||||
const mockThinkObservation: ObservationEvent = {
|
||||
id: "test-think-observation-1",
|
||||
timestamp: Date.now().toString(),
|
||||
source: "environment",
|
||||
tool_name: "think",
|
||||
tool_call_id: "call_think_1",
|
||||
observation: {
|
||||
kind: "ThinkObservation",
|
||||
content: [{ type: "text", text: "Your thought has been logged." }],
|
||||
},
|
||||
action_id: "test-think-action-1",
|
||||
};
|
||||
|
||||
const initialUiEvents = [mockMessageEvent, mockThinkAction];
|
||||
const result = handleEventForUI(mockThinkObservation, initialUiEvents);
|
||||
|
||||
// ThinkObservation should NOT be added - ThinkAction should remain
|
||||
expect(result).toEqual([mockMessageEvent, mockThinkAction]);
|
||||
expect(result).not.toBe(initialUiEvents);
|
||||
});
|
||||
|
||||
it("should NOT add ThinkObservation even when ThinkAction is not found", () => {
|
||||
const mockThinkObservation: ObservationEvent = {
|
||||
id: "test-think-observation-1",
|
||||
timestamp: Date.now().toString(),
|
||||
source: "environment",
|
||||
tool_name: "think",
|
||||
tool_call_id: "call_think_1",
|
||||
observation: {
|
||||
kind: "ThinkObservation",
|
||||
content: [{ type: "text", text: "Your thought has been logged." }],
|
||||
},
|
||||
action_id: "test-think-action-not-found",
|
||||
};
|
||||
|
||||
const initialUiEvents = [mockMessageEvent];
|
||||
const result = handleEventForUI(mockThinkObservation, initialUiEvents);
|
||||
|
||||
// ThinkObservation should never be added to uiEvents
|
||||
expect(result).toEqual([mockMessageEvent]);
|
||||
expect(result).not.toBe(initialUiEvents);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -103,7 +103,7 @@ export interface V1AppConversation {
|
||||
|
||||
export interface Skill {
|
||||
name: string;
|
||||
type: "repo" | "knowledge" | "agentskills";
|
||||
type: "repo" | "knowledge";
|
||||
content: string;
|
||||
triggers: string[];
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@ import type {
|
||||
ConfirmationResponseRequest,
|
||||
ConfirmationResponseResponse,
|
||||
} from "./event-service.types";
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import { OpenHandsEvent } from "#/types/v1/core";
|
||||
|
||||
class EventService {
|
||||
/**
|
||||
@@ -63,27 +61,5 @@ class EventService {
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
// V1 conversations — App Server REST endpoint
|
||||
static async searchEventsV1(conversationId: string, limit = 100) {
|
||||
const { data } = await openHands.get<{
|
||||
items: OpenHandsEvent[];
|
||||
}>(`/api/v1/conversation/${conversationId}/events/search`, {
|
||||
params: { limit },
|
||||
});
|
||||
|
||||
return data.items;
|
||||
}
|
||||
|
||||
// V0 conversations — Legacy REST endpoint
|
||||
static async searchEventsV0(conversationId: string, limit = 100) {
|
||||
const { data } = await openHands.get<{
|
||||
events: OpenHandsEvent[];
|
||||
}>(`/api/conversations/${conversationId}/events`, {
|
||||
params: { limit },
|
||||
});
|
||||
|
||||
return data.events;
|
||||
}
|
||||
}
|
||||
export default EventService;
|
||||
|
||||
@@ -110,7 +110,7 @@ export interface InputMetadata {
|
||||
|
||||
export interface Microagent {
|
||||
name: string;
|
||||
type: "repo" | "knowledge" | "agentskills";
|
||||
type: "repo" | "knowledge";
|
||||
content: string;
|
||||
triggers: string[];
|
||||
}
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
import { useMemo } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { ArrowUpRight } from "lucide-react";
|
||||
import LessonPlanIcon from "#/icons/lesson-plan.svg?react";
|
||||
import { USE_PLANNING_AGENT } from "#/utils/feature-flags";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { MarkdownRenderer } from "#/components/features/markdown/markdown-renderer";
|
||||
|
||||
const MAX_CONTENT_LENGTH = 300;
|
||||
|
||||
interface PlanPreviewProps {
|
||||
/** Raw plan content from PLAN.md file */
|
||||
planContent?: string | null;
|
||||
title?: string;
|
||||
description?: string;
|
||||
onViewClick?: () => void;
|
||||
onBuildClick?: () => void;
|
||||
}
|
||||
|
||||
// TODO: Remove the hardcoded values and use the plan content from the conversation store
|
||||
/* eslint-disable i18next/no-literal-string */
|
||||
export function PlanPreview({
|
||||
planContent,
|
||||
title = "Improve Developer Onboarding and Examples",
|
||||
description = "Based on the analysis of Browser-Use's current documentation and examples, this plan addresses gaps in developer onboarding by creating a progressive learning path, troubleshooting resources, and practical examples that address real-world scenarios (like the LM Studio/local LLM integration issues encountered...",
|
||||
onViewClick,
|
||||
onBuildClick,
|
||||
}: PlanPreviewProps) {
|
||||
@@ -26,13 +24,6 @@ export function PlanPreview({
|
||||
|
||||
const shouldUsePlanningAgent = USE_PLANNING_AGENT();
|
||||
|
||||
// Truncate plan content for preview
|
||||
const truncatedContent = useMemo(() => {
|
||||
if (!planContent) return "";
|
||||
if (planContent.length <= MAX_CONTENT_LENGTH) return planContent;
|
||||
return `${planContent.slice(0, MAX_CONTENT_LENGTH)}...`;
|
||||
}, [planContent]);
|
||||
|
||||
if (!shouldUsePlanningAgent) {
|
||||
return null;
|
||||
}
|
||||
@@ -50,7 +41,6 @@ export function PlanPreview({
|
||||
type="button"
|
||||
onClick={onViewClick}
|
||||
className="flex items-center gap-1 hover:opacity-80 transition-opacity"
|
||||
data-testid="plan-preview-view-button"
|
||||
>
|
||||
<Typography.Text className="font-medium text-[11px] text-white tracking-[0.11px] leading-4">
|
||||
{t(I18nKey.COMMON$VIEW)}
|
||||
@@ -60,27 +50,16 @@ export function PlanPreview({
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div
|
||||
data-testid="plan-preview-content"
|
||||
className="flex flex-col gap-[10px] p-4 text-[15px] text-white leading-[29px]"
|
||||
>
|
||||
{truncatedContent && (
|
||||
<>
|
||||
<MarkdownRenderer includeStandard includeHeadings>
|
||||
{truncatedContent}
|
||||
</MarkdownRenderer>
|
||||
{planContent && planContent.length > MAX_CONTENT_LENGTH && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={onViewClick}
|
||||
className="text-[#4a67bd] cursor-pointer hover:underline text-left"
|
||||
data-testid="plan-preview-read-more-button"
|
||||
>
|
||||
{t(I18nKey.COMMON$READ_MORE)}
|
||||
</button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<div className="flex flex-col gap-[10px] p-4">
|
||||
<h3 className="font-bold text-[19px] text-white leading-[29px]">
|
||||
{title}
|
||||
</h3>
|
||||
<p className="text-[15px] text-white leading-[29px]">
|
||||
{description}
|
||||
<Typography.Text className="text-[#4a67bd] cursor-pointer hover:underline ml-1">
|
||||
{t(I18nKey.COMMON$READ_MORE)}
|
||||
</Typography.Text>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Footer */}
|
||||
@@ -89,7 +68,6 @@ export function PlanPreview({
|
||||
type="button"
|
||||
onClick={onBuildClick}
|
||||
className="bg-white flex items-center justify-center h-[26px] px-2 rounded-[4px] w-[93px] hover:opacity-90 transition-opacity cursor-pointer"
|
||||
data-testid="plan-preview-build-button"
|
||||
>
|
||||
<Typography.Text className="font-medium text-[14px] text-black leading-5">
|
||||
{t(I18nKey.COMMON$BUILD)}{" "}
|
||||
|
||||
@@ -11,15 +11,6 @@ interface SkillItemProps {
|
||||
}
|
||||
|
||||
export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) {
|
||||
let skillTypeLabel: string;
|
||||
if (skill.type === "repo") {
|
||||
skillTypeLabel = "Repository";
|
||||
} else if (skill.type === "knowledge") {
|
||||
skillTypeLabel = "Knowledge";
|
||||
} else {
|
||||
skillTypeLabel = "AgentSkills";
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-md overflow-hidden">
|
||||
<button
|
||||
@@ -34,7 +25,7 @@ export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) {
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<Typography.Text className="px-2 py-1 text-xs rounded-full bg-gray-800 mr-2">
|
||||
{skillTypeLabel}
|
||||
{skill.type === "repo" ? "Repository" : "Knowledge"}
|
||||
</Typography.Text>
|
||||
<Typography.Text className="text-gray-300">
|
||||
{isExpanded ? (
|
||||
|
||||
@@ -87,6 +87,16 @@ export function GitBranchDropdown({
|
||||
[onBranchSelect],
|
||||
);
|
||||
|
||||
// Handle input value change
|
||||
const handleInputValueChange = useCallback(
|
||||
({ inputValue: newInputValue }: { inputValue?: string }) => {
|
||||
if (newInputValue !== undefined) {
|
||||
setInputValue(newInputValue);
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
// Handle menu scroll for infinite loading
|
||||
const handleMenuScroll = useCallback(
|
||||
(event: React.UIEvent<HTMLUListElement>) => {
|
||||
@@ -118,14 +128,8 @@ export function GitBranchDropdown({
|
||||
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
|
||||
handleBranchSelect(newSelectedItem || null);
|
||||
},
|
||||
onInputValueChange: handleInputValueChange,
|
||||
inputValue,
|
||||
// Override Downshift's default input-click behavior to avoid closing/reopening
|
||||
// the menu, which would reset scroll position and break search continuity.
|
||||
stateReducer: (state, actionAndChanges) =>
|
||||
actionAndChanges.type === useCombobox.stateChangeTypes.InputClick &&
|
||||
state.isOpen
|
||||
? { ...actionAndChanges.changes, isOpen: true }
|
||||
: actionAndChanges.changes,
|
||||
});
|
||||
|
||||
// Reset branch selection when repository changes
|
||||
@@ -172,12 +176,12 @@ export function GitBranchDropdown({
|
||||
|
||||
// Initialize input value when selectedBranch changes (but not when user is typing)
|
||||
useEffect(() => {
|
||||
if (selectedBranch && !isOpen) {
|
||||
if (selectedBranch && !isOpen && inputValue !== selectedBranch.name) {
|
||||
setInputValue(selectedBranch.name);
|
||||
} else if (!selectedBranch && !isOpen) {
|
||||
} else if (!selectedBranch && !isOpen && inputValue) {
|
||||
setInputValue("");
|
||||
}
|
||||
}, [selectedBranch, isOpen]);
|
||||
}, [selectedBranch, isOpen, inputValue]);
|
||||
|
||||
const isLoadingState = isLoading || isSearchLoading || isFetchingNextPage;
|
||||
|
||||
@@ -203,10 +207,6 @@ export function GitBranchDropdown({
|
||||
"disabled:bg-[#363636] disabled:cursor-not-allowed disabled:opacity-60",
|
||||
"pl-7 pr-16 text-sm font-normal leading-5", // Space for clear and toggle buttons
|
||||
),
|
||||
// Direct onChange for cursor position preservation
|
||||
onChange: (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setInputValue(e.target.value);
|
||||
},
|
||||
})}
|
||||
data-testid="git-branch-dropdown-input"
|
||||
/>
|
||||
|
||||
@@ -184,6 +184,14 @@ export function GitRepoDropdown({
|
||||
setInputValue("");
|
||||
}, [handleSelectionChange]);
|
||||
|
||||
// Handle input value change
|
||||
const handleInputValueChange = useCallback(
|
||||
({ inputValue: newInputValue }: { inputValue?: string }) => {
|
||||
setInputValue(newInputValue || "");
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
// Handle scroll to bottom for pagination
|
||||
const handleMenuScroll = useCallback(
|
||||
(event: React.UIEvent<HTMLUListElement>) => {
|
||||
@@ -212,14 +220,8 @@ export function GitRepoDropdown({
|
||||
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
|
||||
handleSelectionChange(newSelectedItem);
|
||||
},
|
||||
onInputValueChange: handleInputValueChange,
|
||||
inputValue,
|
||||
// Override Downshift's default input-click behavior to avoid closing/reopening
|
||||
// the menu, which would reset scroll position and break search continuity.
|
||||
stateReducer: (state, actionAndChanges) =>
|
||||
actionAndChanges.type === useCombobox.stateChangeTypes.InputClick &&
|
||||
state.isOpen
|
||||
? { ...actionAndChanges.changes, isOpen: true }
|
||||
: actionAndChanges.changes,
|
||||
});
|
||||
|
||||
// Sync localSelectedItem with external value prop
|
||||
@@ -235,8 +237,6 @@ export function GitRepoDropdown({
|
||||
useEffect(() => {
|
||||
if (selectedRepository && !isOpen) {
|
||||
setInputValue(selectedRepository.full_name);
|
||||
} else if (!selectedRepository && !isOpen) {
|
||||
setInputValue("");
|
||||
}
|
||||
}, [selectedRepository, isOpen]);
|
||||
|
||||
@@ -335,10 +335,6 @@ export function GitRepoDropdown({
|
||||
"disabled:bg-[#363636] disabled:cursor-not-allowed disabled:opacity-60",
|
||||
"pl-7 pr-16 text-sm font-normal leading-5", // Space for clear and toggle buttons
|
||||
),
|
||||
// Direct onChange for cursor position preservation
|
||||
onChange: (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setInputValue(e.target.value);
|
||||
},
|
||||
})}
|
||||
data-testid="git-repo-dropdown"
|
||||
/>
|
||||
|
||||
@@ -27,11 +27,6 @@ export const shouldRenderEvent = (event: OpenHandsEvent) => {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Hide PlanningFileEditorAction - handled separately with PlanPreview component
|
||||
if (actionType === "PlanningFileEditorAction") {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,9 @@ import {
|
||||
isObservationEvent,
|
||||
isAgentErrorEvent,
|
||||
isUserMessageEvent,
|
||||
isPlanningFileEditorObservationEvent,
|
||||
} from "#/types/v1/type-guards";
|
||||
import { MicroagentStatus } from "#/types/microagent-status";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
// TODO: Implement V1 feedback functionality when API supports V1 event IDs
|
||||
// import { useFeedbackExists } from "#/hooks/query/use-feedback-exists";
|
||||
import {
|
||||
@@ -21,8 +19,6 @@ import {
|
||||
ThoughtEventMessage,
|
||||
} from "./event-message-components";
|
||||
import { createSkillReadyEvent } from "./event-content-helpers/create-skill-ready-event";
|
||||
import { PlanPreview } from "../../features/chat/plan-preview";
|
||||
import { shouldShowPlanPreview } from "./hooks/use-plan-preview-events";
|
||||
|
||||
interface EventMessageProps {
|
||||
event: OpenHandsEvent & { isFromPlanningAgent?: boolean };
|
||||
@@ -37,8 +33,6 @@ interface EventMessageProps {
|
||||
tooltip?: string;
|
||||
}>;
|
||||
isInLast10Actions: boolean;
|
||||
/** Set of event IDs that should render PlanPreview (one per user message phase) */
|
||||
planPreviewEventIds?: Set<string>;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -149,10 +143,8 @@ export function EventMessage({
|
||||
microagentPRUrl,
|
||||
actions,
|
||||
isInLast10Actions,
|
||||
planPreviewEventIds,
|
||||
}: EventMessageProps) {
|
||||
const { data: config } = useConfig();
|
||||
const { planContent } = useConversationStore();
|
||||
|
||||
// V1 events use string IDs, but useFeedbackExists expects number
|
||||
// For now, we'll skip feedback functionality for V1 events
|
||||
@@ -206,21 +198,6 @@ export function EventMessage({
|
||||
|
||||
// Observation events - find the corresponding action and render thought + observation
|
||||
if (isObservationEvent(event)) {
|
||||
// Handle PlanningFileEditorObservation specially
|
||||
if (isPlanningFileEditorObservationEvent(event)) {
|
||||
// Only show PlanPreview if this event is marked as the one to display
|
||||
// (last PlanningFileEditorObservation in its phase)
|
||||
if (
|
||||
planPreviewEventIds &&
|
||||
shouldShowPlanPreview(event.id, planPreviewEventIds)
|
||||
) {
|
||||
return <PlanPreview planContent={planContent} />;
|
||||
}
|
||||
// Not the designated preview event for this phase - render nothing
|
||||
// This prevents duplicate previews within the same phase
|
||||
return null;
|
||||
}
|
||||
|
||||
// Find the action that this observation is responding to
|
||||
const correspondingAction = messages.find(
|
||||
(msg) => isActionEvent(msg) && msg.id === event.action_id,
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import { useMemo } from "react";
|
||||
import { OpenHandsEvent } from "#/types/v1/core";
|
||||
import {
|
||||
isUserMessageEvent,
|
||||
isPlanningFileEditorObservationEvent,
|
||||
} from "#/types/v1/type-guards";
|
||||
|
||||
/**
|
||||
* Groups events into phases based on user messages.
|
||||
* A phase starts with a user message and includes all subsequent events
|
||||
* until the next user message.
|
||||
*
|
||||
* @param events - The full list of events
|
||||
* @returns Array of phases, where each phase is an array of events
|
||||
*/
|
||||
function groupEventsByPhase(events: OpenHandsEvent[]): OpenHandsEvent[][] {
|
||||
const phases: OpenHandsEvent[][] = [];
|
||||
let currentPhase: OpenHandsEvent[] = [];
|
||||
|
||||
for (const event of events) {
|
||||
if (isUserMessageEvent(event)) {
|
||||
// Start a new phase with the user message
|
||||
if (currentPhase.length > 0) {
|
||||
phases.push(currentPhase);
|
||||
}
|
||||
currentPhase = [event];
|
||||
} else {
|
||||
// Add event to current phase
|
||||
currentPhase.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
// Don't forget the last phase
|
||||
if (currentPhase.length > 0) {
|
||||
phases.push(currentPhase);
|
||||
}
|
||||
|
||||
return phases;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the last PlanningFileEditorObservation in a phase.
|
||||
*
|
||||
* @param phase - Array of events in a phase
|
||||
* @returns The event ID of the last PlanningFileEditorObservation, or null
|
||||
*/
|
||||
function findLastPlanningObservationInPhase(
|
||||
phase: OpenHandsEvent[],
|
||||
): string | null {
|
||||
// Iterate backwards to find the last one
|
||||
for (let i = phase.length - 1; i >= 0; i -= 1) {
|
||||
const event = phase[i];
|
||||
if (isPlanningFileEditorObservationEvent(event)) {
|
||||
return event.id;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export interface PlanPreviewEventInfo {
|
||||
eventId: string;
|
||||
/** Index of this plan preview in the conversation (1st, 2nd, etc.) */
|
||||
phaseIndex: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to determine which PlanningFileEditorObservation events should render PlanPreview.
|
||||
*
|
||||
* This hook implements phase-based grouping where:
|
||||
* - A phase starts with a user message and ends at the next user message
|
||||
* - Only the LAST PlanningFileEditorObservation in each phase shows PlanPreview
|
||||
* - This ensures only one preview per user request, even with multiple observations
|
||||
*
|
||||
* Scenario handling:
|
||||
* - Scenario 1 (Create plan): Multiple observations in one phase → 1 preview
|
||||
* - Scenario 2 (Create then update): Two user messages → two phases → 2 previews
|
||||
* - Scenario 3 (Create + update while processing): Two user messages → 2 previews
|
||||
*
|
||||
* @param allEvents - Full list of v1 events (for phase detection)
|
||||
* @returns Set of event IDs that should render PlanPreview
|
||||
*/
|
||||
export function usePlanPreviewEvents(allEvents: OpenHandsEvent[]): Set<string> {
|
||||
return useMemo(() => {
|
||||
const planPreviewEventIds = new Set<string>();
|
||||
|
||||
// Group events by phases (user message boundaries)
|
||||
const phases = groupEventsByPhase(allEvents);
|
||||
|
||||
// For each phase, find the last PlanningFileEditorObservation
|
||||
phases.forEach((phase) => {
|
||||
const lastPlanningObservationId =
|
||||
findLastPlanningObservationInPhase(phase);
|
||||
if (lastPlanningObservationId) {
|
||||
planPreviewEventIds.add(lastPlanningObservationId);
|
||||
}
|
||||
});
|
||||
|
||||
return planPreviewEventIds;
|
||||
}, [allEvents]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a specific event should render PlanPreview.
|
||||
*
|
||||
* @param eventId - The event ID to check
|
||||
* @param planPreviewEventIds - Set of event IDs that should render PlanPreview
|
||||
* @returns true if this event should render PlanPreview
|
||||
*/
|
||||
export function shouldShowPlanPreview(
|
||||
eventId: string,
|
||||
planPreviewEventIds: Set<string>,
|
||||
): boolean {
|
||||
return planPreviewEventIds.has(eventId);
|
||||
}
|
||||
@@ -3,7 +3,6 @@ import { OpenHandsEvent } from "#/types/v1/core";
|
||||
import { EventMessage } from "./event-message";
|
||||
import { ChatMessage } from "../../features/chat/chat-message";
|
||||
import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store";
|
||||
import { usePlanPreviewEvents } from "./hooks/use-plan-preview-events";
|
||||
// TODO: Implement microagent functionality for V1 when APIs support V1 event IDs
|
||||
// import { AgentState } from "#/types/agent-state";
|
||||
// import MemoryIcon from "#/icons/memory_icon.svg?react";
|
||||
@@ -19,10 +18,6 @@ export const Messages: React.FC<MessagesProps> = React.memo(
|
||||
|
||||
const optimisticUserMessage = getOptimisticUserMessage();
|
||||
|
||||
// Get the set of event IDs that should render PlanPreview
|
||||
// This ensures only one preview per user message "phase"
|
||||
const planPreviewEventIds = usePlanPreviewEvents(allEvents);
|
||||
|
||||
// TODO: Implement microagent functionality for V1 if needed
|
||||
// For now, we'll skip microagent features
|
||||
|
||||
@@ -35,7 +30,6 @@ export const Messages: React.FC<MessagesProps> = React.memo(
|
||||
messages={allEvents}
|
||||
isLastMessage={messages.length - 1 === index}
|
||||
isInLast10Actions={messages.length - 1 - index < 10}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
// Microagent props - not implemented yet for V1
|
||||
// microagentStatus={undefined}
|
||||
// microagentConversationId={undefined}
|
||||
|
||||
@@ -46,7 +46,6 @@ import { useTracking } from "#/hooks/use-tracking";
|
||||
import { useReadConversationFile } from "#/hooks/mutation/use-read-conversation-file";
|
||||
import useMetricsStore from "#/stores/metrics-store";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
export type V1_WebSocketConnectionState =
|
||||
@@ -307,21 +306,6 @@ export function ConversationWebSocketProvider({
|
||||
latestPlanningFileEventRef.current = null;
|
||||
}, [conversationId]);
|
||||
|
||||
const { data: preloadedEvents } = useConversationHistory(conversationId);
|
||||
|
||||
useEffect(() => {
|
||||
if (!preloadedEvents || preloadedEvents.length === 0) {
|
||||
setIsLoadingHistoryMain(false);
|
||||
return;
|
||||
}
|
||||
|
||||
for (const event of preloadedEvents) {
|
||||
addEvent(event);
|
||||
}
|
||||
|
||||
setIsLoadingHistoryMain(false);
|
||||
}, [preloadedEvents, addEvent]);
|
||||
|
||||
// Separate message handlers for each connection
|
||||
const handleMainMessage = useCallback(
|
||||
(messageEvent: MessageEvent) => {
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import EventService from "#/api/event-service/event-service.api";
|
||||
import { useUserConversation } from "#/hooks/query/use-user-conversation";
|
||||
|
||||
export const useConversationHistory = (conversationId?: string) => {
|
||||
const { data: conversation } = useUserConversation(conversationId ?? null);
|
||||
|
||||
return useQuery({
|
||||
queryKey: ["conversation-history", conversationId, conversation],
|
||||
enabled: !!conversationId && !!conversation,
|
||||
queryFn: async () => {
|
||||
if (!conversationId || !conversation) return [];
|
||||
|
||||
if (conversation.conversation_version === "V1") {
|
||||
return EventService.searchEventsV1(conversationId);
|
||||
}
|
||||
|
||||
return EventService.searchEventsV0(conversationId);
|
||||
},
|
||||
staleTime: 30_000,
|
||||
});
|
||||
};
|
||||
@@ -20,7 +20,6 @@ export default function LoginPage() {
|
||||
recaptchaBlocked,
|
||||
emailVerificationModalOpen,
|
||||
setEmailVerificationModalOpen,
|
||||
userId,
|
||||
} = useEmailVerification();
|
||||
|
||||
const gitHubAuthUrl = useGitHubAuthUrl({
|
||||
@@ -78,7 +77,6 @@ export default function LoginPage() {
|
||||
onClose={() => {
|
||||
setEmailVerificationModalOpen(false);
|
||||
}}
|
||||
userId={userId}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
|
||||
@@ -5,7 +5,6 @@ import {
|
||||
Outlet,
|
||||
useNavigate,
|
||||
useLocation,
|
||||
useSearchParams,
|
||||
} from "react-router";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
@@ -68,7 +67,6 @@ export default function MainApp() {
|
||||
const appTitle = useAppTitle();
|
||||
const navigate = useNavigate();
|
||||
const { pathname } = useLocation();
|
||||
const [searchParams] = useSearchParams();
|
||||
const isOnTosPage = useIsOnTosPage();
|
||||
const { data: settings } = useSettings();
|
||||
const { migrateUserConsent } = useMigrateUserConsent();
|
||||
@@ -184,18 +182,13 @@ export default function MainApp() {
|
||||
|
||||
React.useEffect(() => {
|
||||
if (shouldRedirectToLogin) {
|
||||
// Include search params in returnTo to preserve query string (e.g., user_code for device OAuth)
|
||||
const searchString = searchParams.toString();
|
||||
let fullPath = "";
|
||||
if (pathname !== "/") {
|
||||
fullPath = searchString ? `${pathname}?${searchString}` : pathname;
|
||||
}
|
||||
const loginUrl = fullPath
|
||||
? `/login?returnTo=${encodeURIComponent(fullPath)}`
|
||||
const returnTo = pathname !== "/" ? pathname : "";
|
||||
const loginUrl = returnTo
|
||||
? `/login?returnTo=${encodeURIComponent(returnTo)}`
|
||||
: "/login";
|
||||
navigate(loginUrl, { replace: true });
|
||||
}
|
||||
}, [shouldRedirectToLogin, pathname, searchParams, navigate]);
|
||||
}, [shouldRedirectToLogin, pathname, navigate]);
|
||||
|
||||
if (shouldRedirectToLogin) {
|
||||
return (
|
||||
|
||||
@@ -213,37 +213,6 @@ export interface BrowserCloseTabAction extends ActionBase<"BrowserCloseTabAction
|
||||
tab_id: string;
|
||||
}
|
||||
|
||||
export interface PlanningFileEditorAction extends ActionBase<"PlanningFileEditorAction"> {
|
||||
/**
|
||||
* The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.
|
||||
*/
|
||||
command: "view" | "create" | "str_replace" | "insert" | "undo_edit";
|
||||
/**
|
||||
* Absolute path to file (typically /workspace/project/PLAN.md).
|
||||
*/
|
||||
path: string;
|
||||
/**
|
||||
* Required parameter of `create` command, with the content of the file to be created.
|
||||
*/
|
||||
file_text: string | null;
|
||||
/**
|
||||
* Required parameter of `str_replace` command containing the string in `path` to replace.
|
||||
*/
|
||||
old_str: string | null;
|
||||
/**
|
||||
* Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.
|
||||
*/
|
||||
new_str: string | null;
|
||||
/**
|
||||
* Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`. Must be >= 1.
|
||||
*/
|
||||
insert_line: number | null;
|
||||
/**
|
||||
* Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown.
|
||||
*/
|
||||
view_range: [number, number] | null;
|
||||
}
|
||||
|
||||
export type Action =
|
||||
| MCPToolAction
|
||||
| FinishAction
|
||||
@@ -253,7 +222,6 @@ export type Action =
|
||||
| FileEditorAction
|
||||
| StrReplaceEditorAction
|
||||
| TaskTrackerAction
|
||||
| PlanningFileEditorAction
|
||||
| BrowserNavigateAction
|
||||
| BrowserClickAction
|
||||
| BrowserTypeAction
|
||||
|
||||
@@ -4,7 +4,6 @@ import { isObservationEvent } from "#/types/v1/type-guards";
|
||||
/**
|
||||
* Handles adding an event to the UI events array
|
||||
* Replaces actions with observations when they arrive (so UI shows observation instead of action)
|
||||
* Exception: ThinkAction is NOT replaced because the thought content is in the action, not in the observation
|
||||
*/
|
||||
export const handleEventForUI = (
|
||||
event: OpenHandsEvent,
|
||||
@@ -13,17 +12,12 @@ export const handleEventForUI = (
|
||||
const newUiEvents = [...uiEvents];
|
||||
|
||||
if (isObservationEvent(event)) {
|
||||
// Don't add ThinkObservation at all - we keep the ThinkAction instead
|
||||
// The thought content is in the action, not the observation
|
||||
if (event.observation.kind === "ThinkObservation") {
|
||||
return newUiEvents;
|
||||
}
|
||||
|
||||
// Find and replace the corresponding action from uiEvents
|
||||
const actionIndex = newUiEvents.findIndex(
|
||||
(uiEvent) => uiEvent.id === event.action_id,
|
||||
);
|
||||
if (actionIndex !== -1) {
|
||||
// Replace the action with the observation
|
||||
newUiEvents[actionIndex] = event;
|
||||
} else {
|
||||
// Action not found in uiEvents, just add the observation
|
||||
|
||||
@@ -7,13 +7,6 @@ import { I18nextProvider, initReactI18next } from "react-i18next";
|
||||
import i18n from "i18next";
|
||||
import { vi } from "vitest";
|
||||
import { AxiosError } from "axios";
|
||||
import {
|
||||
ActionEvent,
|
||||
MessageEvent,
|
||||
ObservationEvent,
|
||||
PlanningFileEditorObservation,
|
||||
} from "#/types/v1/core";
|
||||
import { SecurityRisk } from "#/types/v1/core";
|
||||
|
||||
export const useParamsMock = vi.fn(() => ({
|
||||
conversationId: "test-conversation-id",
|
||||
@@ -105,100 +98,3 @@ export const createAxiosError = (
|
||||
config: {},
|
||||
},
|
||||
);
|
||||
|
||||
// Helper to create a PlanningFileEditorAction event
|
||||
export const createPlanningFileEditorActionEvent = (
|
||||
id: string,
|
||||
): ActionEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "agent",
|
||||
thought: [{ type: "text", text: "Planning action" }],
|
||||
thinking_blocks: [],
|
||||
action: {
|
||||
kind: "PlanningFileEditorAction",
|
||||
command: "create",
|
||||
path: "/workspace/PLAN.md",
|
||||
file_text: "Plan content",
|
||||
old_str: null,
|
||||
new_str: null,
|
||||
insert_line: null,
|
||||
view_range: null,
|
||||
},
|
||||
tool_name: "planning_file_editor",
|
||||
tool_call_id: "call-1",
|
||||
tool_call: {
|
||||
id: "call-1",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "planning_file_editor",
|
||||
arguments: '{"command": "create"}',
|
||||
},
|
||||
},
|
||||
llm_response_id: "response-1",
|
||||
security_risk: SecurityRisk.UNKNOWN,
|
||||
});
|
||||
|
||||
// Helper to create a non-planning action event
|
||||
export const createOtherActionEvent = (id: string): ActionEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "agent",
|
||||
thought: [{ type: "text", text: "Other action" }],
|
||||
thinking_blocks: [],
|
||||
action: {
|
||||
kind: "ExecuteBashAction",
|
||||
command: "echo test",
|
||||
is_input: false,
|
||||
timeout: null,
|
||||
reset: false,
|
||||
},
|
||||
tool_name: "execute_bash",
|
||||
tool_call_id: "call-1",
|
||||
tool_call: {
|
||||
id: "call-1",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "execute_bash",
|
||||
arguments: '{"command": "echo test"}',
|
||||
},
|
||||
},
|
||||
llm_response_id: "response-1",
|
||||
security_risk: SecurityRisk.UNKNOWN,
|
||||
});
|
||||
|
||||
// Helper to create a PlanningFileEditorObservation event
|
||||
export const createPlanningObservationEvent = (
|
||||
id: string,
|
||||
actionId: string = "action-1",
|
||||
): ObservationEvent<PlanningFileEditorObservation> => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "environment",
|
||||
tool_name: "planning_file_editor",
|
||||
tool_call_id: "call-1",
|
||||
action_id: actionId,
|
||||
observation: {
|
||||
kind: "PlanningFileEditorObservation",
|
||||
content: [{ type: "text", text: "Plan content" }],
|
||||
is_error: false,
|
||||
command: "create",
|
||||
path: "/workspace/PLAN.md",
|
||||
prev_exist: false,
|
||||
old_content: null,
|
||||
new_content: "Plan content",
|
||||
},
|
||||
});
|
||||
|
||||
// Helper to create a user message event
|
||||
export const createUserMessageEvent = (id: string): MessageEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "user",
|
||||
llm_message: {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "User message" }],
|
||||
},
|
||||
activated_microagents: [],
|
||||
extended_content: [],
|
||||
});
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -14,7 +14,6 @@ from openhands.app_server.sandbox.sandbox_models import SandboxStatus
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.conversation.state import ConversationExecutionStatus
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.plugin import PluginSource
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@@ -25,45 +24,6 @@ class AgentType(Enum):
|
||||
PLAN = 'plan'
|
||||
|
||||
|
||||
class PluginSpec(PluginSource):
|
||||
"""Specification for loading a plugin into a conversation.
|
||||
|
||||
Extends SDK's PluginSource with user-provided plugin configuration parameters.
|
||||
Inherits source, ref, and repo_path fields along with their validation.
|
||||
"""
|
||||
|
||||
parameters: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description='User-provided values for plugin input parameters',
|
||||
)
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Extract a friendly display name from the plugin source.
|
||||
|
||||
Examples:
|
||||
- 'github:owner/repo' -> 'repo'
|
||||
- 'https://github.com/owner/repo.git' -> 'repo.git'
|
||||
- '/local/path' -> 'path'
|
||||
"""
|
||||
return self.source.split('/')[-1] if '/' in self.source else self.source
|
||||
|
||||
def format_params_as_text(self, indent: str = '') -> str | None:
|
||||
"""Format parameters as a readable text block for display.
|
||||
|
||||
Args:
|
||||
indent: Optional prefix to add before each parameter line.
|
||||
|
||||
Returns:
|
||||
Formatted parameters string, or None if no parameters.
|
||||
"""
|
||||
if not self.parameters:
|
||||
return None
|
||||
return '\n'.join(
|
||||
f'{indent}- {key}: {value}' for key, value in self.parameters.items()
|
||||
)
|
||||
|
||||
|
||||
class AppConversationInfo(BaseModel):
|
||||
"""Conversation info which does not contain status."""
|
||||
|
||||
@@ -158,15 +118,6 @@ class AppConversationStartRequest(OpenHandsModel):
|
||||
|
||||
public: bool | None = None
|
||||
|
||||
# Plugin parameters - for loading remote plugins into the conversation
|
||||
plugins: list[PluginSpec] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
'List of plugins to load for this conversation. Plugins are loaded '
|
||||
'and their skills/MCP config are merged into the agent.'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class AppConversationUpdateRequest(BaseModel):
|
||||
public: bool
|
||||
@@ -196,8 +147,7 @@ class AppConversationStartTask(OpenHandsModel):
|
||||
|
||||
Because starting an app conversation can be slow (And can involve starting a sandbox),
|
||||
we kick off a background task for it. Once the conversation is started, the app_conversation_id
|
||||
is populated.
|
||||
"""
|
||||
is populated."""
|
||||
|
||||
id: OpenHandsUUID = Field(default_factory=uuid4)
|
||||
created_by_user_id: str | None
|
||||
@@ -226,6 +176,6 @@ class SkillResponse(BaseModel):
|
||||
"""Response model for skills endpoint."""
|
||||
|
||||
name: str
|
||||
type: Literal['repo', 'knowledge', 'agentskills']
|
||||
type: Literal['repo', 'knowledge']
|
||||
content: str
|
||||
triggers: list[str] = []
|
||||
|
||||
@@ -503,6 +503,13 @@ async def get_conversation_skills(
|
||||
|
||||
agent_server_url = replace_localhost_hostname_for_docker(agent_server_url)
|
||||
|
||||
# Create remote workspace
|
||||
remote_workspace = AsyncRemoteWorkspace(
|
||||
host=agent_server_url,
|
||||
api_key=sandbox.session_api_key,
|
||||
working_dir=sandbox_spec.working_dir,
|
||||
)
|
||||
|
||||
# Load skills from all sources
|
||||
logger.info(f'Loading skills for conversation {conversation_id}')
|
||||
|
||||
@@ -511,9 +518,9 @@ async def get_conversation_skills(
|
||||
if isinstance(app_conversation_service, AppConversationServiceBase):
|
||||
all_skills = await app_conversation_service.load_and_merge_all_skills(
|
||||
sandbox,
|
||||
remote_workspace,
|
||||
conversation.selected_repository,
|
||||
sandbox_spec.working_dir,
|
||||
agent_server_url,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -524,11 +531,9 @@ async def get_conversation_skills(
|
||||
# Transform skills to response format
|
||||
skills_response = []
|
||||
for skill in all_skills:
|
||||
# Determine type based on AgentSkills format and trigger
|
||||
skill_type: Literal['repo', 'knowledge', 'agentskills']
|
||||
if skill.is_agentskills_format:
|
||||
skill_type = 'agentskills'
|
||||
elif skill.trigger is None:
|
||||
# Determine type based on trigger
|
||||
skill_type: Literal['repo', 'knowledge']
|
||||
if skill.trigger is None:
|
||||
skill_type = 'repo'
|
||||
else:
|
||||
skill_type = 'knowledge'
|
||||
@@ -621,6 +626,7 @@ async def _stream_app_conversation_start(
|
||||
user_context: UserContext,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream a json list, item by item."""
|
||||
|
||||
# Because the original dependencies are closed after the method returns, we need
|
||||
# a new dependency context which will continue intil the stream finishes.
|
||||
state = InjectorState()
|
||||
|
||||
@@ -95,7 +95,6 @@ class AppConversationService(ABC):
|
||||
task: AppConversationStartTask,
|
||||
sandbox: SandboxInfo,
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
agent_server_url: str,
|
||||
) -> AsyncGenerator[AppConversationStartTask, None]:
|
||||
"""Run the setup scripts for the project and yield status updates"""
|
||||
yield task
|
||||
@@ -105,8 +104,7 @@ class AppConversationService(ABC):
|
||||
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||
) -> AppConversation | None:
|
||||
"""Update an app conversation and return it. Return None if the conversation
|
||||
did not exist.
|
||||
"""
|
||||
did not exist."""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
||||
|
||||
@@ -21,16 +21,18 @@ from openhands.app_server.app_conversation.app_conversation_service import (
|
||||
AppConversationService,
|
||||
)
|
||||
from openhands.app_server.app_conversation.skill_loader import (
|
||||
build_org_config,
|
||||
build_sandbox_config,
|
||||
load_skills_from_agent_server,
|
||||
load_global_skills,
|
||||
load_org_skills,
|
||||
load_repo_skills,
|
||||
load_sandbox_skills,
|
||||
merge_skills,
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.sdk import Agent
|
||||
from openhands.sdk.context.agent_context import AgentContext
|
||||
from openhands.sdk.context.condenser import LLMSummarizingCondenser
|
||||
from openhands.sdk.context.skills import Skill
|
||||
from openhands.sdk.context.skills import load_user_skills
|
||||
from openhands.sdk.llm import LLM
|
||||
from openhands.sdk.security.analyzer import SecurityAnalyzerBase
|
||||
from openhands.sdk.security.confirmation_policy import (
|
||||
@@ -51,8 +53,7 @@ PRE_COMMIT_LOCAL = '.git/hooks/pre-commit.local'
|
||||
class AppConversationServiceBase(AppConversationService, ABC):
|
||||
"""App Conversation service which adds git specific functionality.
|
||||
|
||||
Sets up repositories and installs hooks
|
||||
"""
|
||||
Sets up repositories and installs hooks"""
|
||||
|
||||
init_git_in_empty_workspace: bool
|
||||
user_context: UserContext
|
||||
@@ -60,74 +61,67 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
async def load_and_merge_all_skills(
|
||||
self,
|
||||
sandbox: SandboxInfo,
|
||||
remote_workspace: AsyncRemoteWorkspace,
|
||||
selected_repository: str | None,
|
||||
working_dir: str,
|
||||
agent_server_url: str,
|
||||
) -> list[Skill]:
|
||||
"""Load skills from all sources via the agent-server.
|
||||
) -> list:
|
||||
"""Load skills from all sources and merge them.
|
||||
|
||||
This method calls the agent-server's /api/skills endpoint to load and
|
||||
merge skills from all sources. The agent-server handles:
|
||||
- Public skills (from OpenHands/skills GitHub repo)
|
||||
- User skills (from ~/.openhands/skills/)
|
||||
- Organization skills (from {org}/.openhands repo)
|
||||
- Project/repo skills (from workspace .openhands/skills/)
|
||||
- Sandbox skills (from exposed URLs)
|
||||
This method handles all errors gracefully and will return an empty list
|
||||
if skill loading fails completely.
|
||||
|
||||
Args:
|
||||
sandbox: SandboxInfo containing exposed URLs and agent-server URL
|
||||
remote_workspace: AsyncRemoteWorkspace for loading repo skills
|
||||
selected_repository: Repository name or None
|
||||
working_dir: Working directory path
|
||||
agent_server_url: Agent-server URL (required)
|
||||
|
||||
Returns:
|
||||
List of merged Skill objects from all sources, or empty list on failure
|
||||
"""
|
||||
try:
|
||||
_logger.debug('Loading skills for V1 conversation via agent-server')
|
||||
_logger.debug('Loading skills for V1 conversation')
|
||||
|
||||
if not agent_server_url:
|
||||
_logger.warning('No agent-server URL available, cannot load skills')
|
||||
return []
|
||||
# Load skills from all sources
|
||||
sandbox_skills = load_sandbox_skills(sandbox)
|
||||
global_skills = load_global_skills()
|
||||
# Load user skills from ~/.openhands/skills/ directory
|
||||
# Uses the SDK's load_user_skills() function which handles loading from
|
||||
# ~/.openhands/skills/ and ~/.openhands/microagents/ (for backward compatibility)
|
||||
try:
|
||||
user_skills = load_user_skills()
|
||||
_logger.info(
|
||||
f'Loaded {len(user_skills)} user skills: {[s.name for s in user_skills]}'
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load user skills: {str(e)}')
|
||||
user_skills = []
|
||||
|
||||
# Build org config (authentication handled by app-server)
|
||||
org_config = await build_org_config(selected_repository, self.user_context)
|
||||
# Load organization-level skills
|
||||
org_skills = await load_org_skills(
|
||||
remote_workspace, selected_repository, working_dir, self.user_context
|
||||
)
|
||||
|
||||
# Build sandbox config (exposed URLs)
|
||||
sandbox_config = build_sandbox_config(sandbox)
|
||||
repo_skills = await load_repo_skills(
|
||||
remote_workspace, selected_repository, working_dir
|
||||
)
|
||||
|
||||
# Determine project directory for project skills
|
||||
project_dir = working_dir
|
||||
if selected_repository:
|
||||
repo_name = selected_repository.split('/')[-1]
|
||||
project_dir = f'{working_dir}/{repo_name}'
|
||||
|
||||
# Single API call to agent-server for ALL skills
|
||||
all_skills = await load_skills_from_agent_server(
|
||||
agent_server_url=agent_server_url,
|
||||
session_api_key=sandbox.session_api_key,
|
||||
project_dir=project_dir,
|
||||
org_config=org_config,
|
||||
sandbox_config=sandbox_config,
|
||||
load_public=True,
|
||||
load_user=True,
|
||||
load_project=True,
|
||||
load_org=True,
|
||||
# Merge all skills (later lists override earlier ones)
|
||||
# Precedence: sandbox < global < user < org < repo
|
||||
all_skills = merge_skills(
|
||||
[sandbox_skills, global_skills, user_skills, org_skills, repo_skills]
|
||||
)
|
||||
|
||||
_logger.info(
|
||||
f'Loaded {len(all_skills)} total skills from agent-server: '
|
||||
f'{[s.name for s in all_skills]}'
|
||||
f'Loaded {len(all_skills)} total skills: {[s.name for s in all_skills]}'
|
||||
)
|
||||
|
||||
return all_skills
|
||||
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load skills: {e}', exc_info=True)
|
||||
# Return empty list on failure - skills will be loaded again later if needed
|
||||
return []
|
||||
|
||||
def _create_agent_with_skills(self, agent, skills: list[Skill]):
|
||||
def _create_agent_with_skills(self, agent, skills: list):
|
||||
"""Create or update agent with skills in its context.
|
||||
|
||||
Args:
|
||||
@@ -138,9 +132,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
Updated agent with skills in context
|
||||
"""
|
||||
if agent.agent_context:
|
||||
# Merge with existing context (new skills override existing ones)
|
||||
# Merge with existing context
|
||||
existing_skills = agent.agent_context.skills
|
||||
all_skills = self._merge_skills([existing_skills, skills])
|
||||
all_skills = merge_skills([skills, existing_skills])
|
||||
agent = agent.model_copy(
|
||||
update={
|
||||
'agent_context': agent.agent_context.model_copy(
|
||||
@@ -155,25 +149,6 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
|
||||
return agent
|
||||
|
||||
def _merge_skills(self, skill_lists: list[list[Skill]]) -> list[Skill]:
|
||||
"""Merge multiple skill lists, avoiding duplicates by name.
|
||||
|
||||
Later lists take precedence over earlier lists for duplicate names.
|
||||
|
||||
Args:
|
||||
skill_lists: List of skill lists to merge
|
||||
|
||||
Returns:
|
||||
Deduplicated list of skills with later lists overriding earlier ones
|
||||
"""
|
||||
skills_by_name: dict[str, Skill] = {}
|
||||
|
||||
for skill_list in skill_lists:
|
||||
for skill in skill_list:
|
||||
skills_by_name[skill.name] = skill
|
||||
|
||||
return list(skills_by_name.values())
|
||||
|
||||
async def _load_skills_and_update_agent(
|
||||
self,
|
||||
sandbox: SandboxInfo,
|
||||
@@ -194,10 +169,8 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
Updated agent with skills loaded into context
|
||||
"""
|
||||
# Load and merge all skills
|
||||
# Extract agent_server_url from remote_workspace host
|
||||
agent_server_url = remote_workspace.host
|
||||
all_skills = await self.load_and_merge_all_skills(
|
||||
sandbox, selected_repository, working_dir, agent_server_url
|
||||
sandbox, remote_workspace, selected_repository, working_dir
|
||||
)
|
||||
|
||||
# Update agent with skills
|
||||
@@ -210,7 +183,6 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
task: AppConversationStartTask,
|
||||
sandbox: SandboxInfo,
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
agent_server_url: str,
|
||||
) -> AsyncGenerator[AppConversationStartTask, None]:
|
||||
task.status = AppConversationStartTaskStatus.PREPARING_REPOSITORY
|
||||
yield task
|
||||
@@ -228,9 +200,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
yield task
|
||||
await self.load_and_merge_all_skills(
|
||||
sandbox,
|
||||
workspace,
|
||||
task.request.selected_repository,
|
||||
workspace.working_dir,
|
||||
agent_server_url,
|
||||
)
|
||||
|
||||
async def _configure_git_user_settings(
|
||||
@@ -485,6 +457,7 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
security_analyzer_str: String value from settings
|
||||
httpx_client: HTTP client for making API requests
|
||||
"""
|
||||
|
||||
if session_api_key is None:
|
||||
return
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartTask,
|
||||
AppConversationStartTaskStatus,
|
||||
AppConversationUpdateRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_service import (
|
||||
AppConversationService,
|
||||
@@ -80,7 +79,6 @@ from openhands.experiments.experiment_manager import ExperimentManagerImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk import Agent, AgentContext, LocalWorkspace
|
||||
from openhands.sdk.llm import LLM
|
||||
from openhands.sdk.plugin import PluginSource
|
||||
from openhands.sdk.secret import LookupSecret, SecretValue, StaticSecret
|
||||
from openhands.sdk.utils.paging import page_iterator
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
@@ -239,7 +237,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
working_dir=sandbox_spec.working_dir,
|
||||
)
|
||||
async for updated_task in self.run_setup_scripts(
|
||||
task, sandbox, remote_workspace, agent_server_url
|
||||
task, sandbox, remote_workspace
|
||||
):
|
||||
yield updated_task
|
||||
|
||||
@@ -256,7 +254,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
request.conversation_id,
|
||||
remote_workspace=remote_workspace,
|
||||
selected_repository=request.selected_repository,
|
||||
plugins=request.plugins,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -957,79 +954,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
return agent.model_copy(update=updates)
|
||||
return agent
|
||||
|
||||
def _construct_initial_message_with_plugin_params(
|
||||
self,
|
||||
initial_message: SendMessageRequest | None,
|
||||
plugins: list[PluginSpec] | None,
|
||||
) -> SendMessageRequest | None:
|
||||
"""Incorporate plugin parameters into the initial message if specified.
|
||||
|
||||
Plugin parameters are formatted and appended to the initial message so the
|
||||
agent has context about the user-provided configuration values.
|
||||
|
||||
Args:
|
||||
initial_message: The original initial message, if any
|
||||
plugins: List of plugin specifications with optional parameters
|
||||
|
||||
Returns:
|
||||
The initial message with plugin parameters incorporated, or the
|
||||
original message if no plugin parameters are specified
|
||||
"""
|
||||
from openhands.agent_server.models import TextContent
|
||||
|
||||
if not plugins:
|
||||
return initial_message
|
||||
|
||||
# Collect formatted parameters from plugins that have them
|
||||
plugins_with_params = [p for p in plugins if p.parameters]
|
||||
if not plugins_with_params:
|
||||
return initial_message
|
||||
|
||||
# Format parameters, grouped by plugin if multiple
|
||||
if len(plugins_with_params) == 1:
|
||||
params_text = plugins_with_params[0].format_params_as_text()
|
||||
plugin_params_message = (
|
||||
f'\n\nPlugin Configuration Parameters:\n{params_text}'
|
||||
)
|
||||
else:
|
||||
# Group by plugin name for clarity
|
||||
formatted_plugins = []
|
||||
for plugin in plugins_with_params:
|
||||
params_text = plugin.format_params_as_text(indent=' ')
|
||||
if params_text:
|
||||
formatted_plugins.append(f'{plugin.display_name}:\n{params_text}')
|
||||
|
||||
plugin_params_message = (
|
||||
'\n\nPlugin Configuration Parameters:\n' + '\n'.join(formatted_plugins)
|
||||
)
|
||||
|
||||
if initial_message is None:
|
||||
# Create a new message with just the plugin parameters
|
||||
return SendMessageRequest(
|
||||
content=[TextContent(text=plugin_params_message.strip())],
|
||||
run=True,
|
||||
)
|
||||
|
||||
# Append plugin parameters to existing message content
|
||||
new_content = list(initial_message.content)
|
||||
if new_content and isinstance(new_content[-1], TextContent):
|
||||
# Append to the last text content
|
||||
last_content = new_content[-1]
|
||||
new_content[-1] = TextContent(
|
||||
text=last_content.text + plugin_params_message,
|
||||
cache_prompt=last_content.cache_prompt,
|
||||
enable_truncation=last_content.enable_truncation,
|
||||
)
|
||||
else:
|
||||
# Add as new text content
|
||||
new_content.append(TextContent(text=plugin_params_message.strip()))
|
||||
|
||||
return SendMessageRequest(
|
||||
role=initial_message.role,
|
||||
content=new_content,
|
||||
run=initial_message.run,
|
||||
)
|
||||
|
||||
async def _finalize_conversation_request(
|
||||
self,
|
||||
agent: Agent,
|
||||
@@ -1042,7 +966,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace: AsyncRemoteWorkspace | None,
|
||||
selected_repository: str | None,
|
||||
working_dir: str,
|
||||
plugins: list[PluginSpec] | None = None,
|
||||
) -> StartConversationRequest:
|
||||
"""Finalize the conversation request with experiment variants and skills.
|
||||
|
||||
@@ -1057,7 +980,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace: Optional remote workspace for skills loading
|
||||
selected_repository: Optional repository name
|
||||
working_dir: Working directory path
|
||||
plugins: Optional list of plugin specifications to load
|
||||
|
||||
Returns:
|
||||
Complete StartConversationRequest ready for use
|
||||
@@ -1084,23 +1006,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
_logger.warning(f'Failed to load skills: {e}', exc_info=True)
|
||||
# Continue without skills - don't fail conversation startup
|
||||
|
||||
# Incorporate plugin parameters into initial message if specified
|
||||
final_initial_message = self._construct_initial_message_with_plugin_params(
|
||||
initial_message, plugins
|
||||
)
|
||||
|
||||
# Convert PluginSpec list to SDK PluginSource list for agent server
|
||||
sdk_plugins: list[PluginSource] | None = None
|
||||
if plugins:
|
||||
sdk_plugins = [
|
||||
PluginSource(
|
||||
source=p.source,
|
||||
ref=p.ref,
|
||||
repo_path=p.repo_path,
|
||||
)
|
||||
for p in plugins
|
||||
]
|
||||
|
||||
# Create and return the final request
|
||||
return StartConversationRequest(
|
||||
conversation_id=conversation_id,
|
||||
@@ -1109,9 +1014,8 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
confirmation_policy=self._select_confirmation_policy(
|
||||
bool(user.confirmation_mode), user.security_analyzer
|
||||
),
|
||||
initial_message=final_initial_message,
|
||||
initial_message=initial_message,
|
||||
secrets=secrets,
|
||||
plugins=sdk_plugins,
|
||||
)
|
||||
|
||||
async def _build_start_conversation_request_for_user(
|
||||
@@ -1126,7 +1030,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
conversation_id: UUID | None = None,
|
||||
remote_workspace: AsyncRemoteWorkspace | None = None,
|
||||
selected_repository: str | None = None,
|
||||
plugins: list[PluginSpec] | None = None,
|
||||
) -> StartConversationRequest:
|
||||
"""Build a complete conversation request for a user.
|
||||
|
||||
@@ -1135,7 +1038,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
2. Configuring LLM and MCP settings
|
||||
3. Creating an agent with appropriate context
|
||||
4. Finalizing the request with skills and experiment variants
|
||||
5. Passing plugins to the agent server for remote plugin loading
|
||||
"""
|
||||
user = await self.user_context.get_user_info()
|
||||
workspace = LocalWorkspace(working_dir=working_dir)
|
||||
@@ -1168,7 +1070,6 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
remote_workspace,
|
||||
selected_repository,
|
||||
working_dir,
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
async def update_agent_server_conversation_title(
|
||||
@@ -1223,8 +1124,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||
) -> AppConversation | None:
|
||||
"""Update an app conversation and return it. Return None if the conversation
|
||||
did not exist.
|
||||
"""
|
||||
did not exist."""
|
||||
info = await self.app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
@@ -1395,7 +1295,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
# Get all events for this conversation
|
||||
i = 0
|
||||
async for event in page_iterator(
|
||||
self.event_service.search_events, conversation_id=conversation_id
|
||||
self.event_service.search_events, conversation_id__eq=conversation_id
|
||||
):
|
||||
event_filename = f'event_{i:06d}_{event.id}.json'
|
||||
event_path = os.path.join(temp_dir, event_filename)
|
||||
|
||||
@@ -1,37 +1,34 @@
|
||||
"""Utilities for loading skills for V1 conversations.
|
||||
|
||||
This module provides functions to load skills from the agent-server,
|
||||
which centralizes all skill loading logic. The app-server acts as a
|
||||
thin proxy that:
|
||||
1. Builds the org_config with authentication information
|
||||
2. Builds the sandbox_config with exposed URLs
|
||||
3. Calls the agent-server's /api/skills endpoint
|
||||
This module provides functions to load skills from various sources:
|
||||
- Global skills from OpenHands/skills/
|
||||
- User skills from ~/.openhands/skills/
|
||||
- Repository-level skills from the workspace
|
||||
|
||||
All source-specific skill loading is handled by the agent-server.
|
||||
All skills are used in V1 conversations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
import openhands
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
from openhands.sdk.context.skills import Skill
|
||||
from openhands.sdk.context.skills.trigger import KeywordTrigger, TaskTrigger
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExposedUrlConfig(BaseModel):
|
||||
"""Configuration for an exposed URL in sandbox config."""
|
||||
|
||||
name: str
|
||||
url: str
|
||||
port: int
|
||||
|
||||
# Path to global skills directory
|
||||
GLOBAL_SKILLS_DIR = os.path.join(
|
||||
os.path.dirname(os.path.dirname(openhands.__file__)),
|
||||
'skills',
|
||||
)
|
||||
WORK_HOSTS_SKILL = """The user has access to the following hosts for accessing a web application,
|
||||
each of which has a corresponding port:"""
|
||||
|
||||
WORK_HOSTS_SKILL_FOOTER = """
|
||||
When starting a web server, use the corresponding ports via environment variables:
|
||||
@@ -48,30 +45,96 @@ app.run(host='0.0.0.0', port=int(os.environ.get('WORKER_1', 12000)))
|
||||
```"""
|
||||
|
||||
|
||||
class SandboxConfig(BaseModel):
|
||||
"""Sandbox configuration for agent-server API request."""
|
||||
def _find_and_load_global_skill_files(skill_dir: Path) -> list[Skill]:
|
||||
"""Find and load all .md files from the global skills directory.
|
||||
|
||||
exposed_urls: list[ExposedUrlConfig]
|
||||
Args:
|
||||
skill_dir: Path to the global skills directory
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from the files (excluding README.md)
|
||||
"""
|
||||
skills = []
|
||||
|
||||
try:
|
||||
# Find all .md files in the directory (excluding README.md)
|
||||
md_files = [f for f in skill_dir.glob('*.md') if f.name.lower() != 'readme.md']
|
||||
|
||||
# Load skills from the found files
|
||||
for file_path in md_files:
|
||||
try:
|
||||
skill = Skill.load(file_path, skill_dir)
|
||||
skills.append(skill)
|
||||
_logger.debug(f'Loaded global skill: {skill.name} from {file_path}')
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
f'Failed to load global skill from {file_path}: {str(e)}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_logger.debug(f'Failed to find global skill files: {str(e)}')
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
class OrgConfig(BaseModel):
|
||||
"""Organization configuration for agent-server API request."""
|
||||
|
||||
repository: str
|
||||
provider: str
|
||||
org_repo_url: str
|
||||
org_name: str
|
||||
def load_sandbox_skills(sandbox: SandboxInfo) -> list[Skill]:
|
||||
"""Load skills specific to the sandbox, including exposed ports / urls."""
|
||||
if not sandbox.exposed_urls:
|
||||
return []
|
||||
urls = [url for url in sandbox.exposed_urls if url.name.startswith('WORKER_')]
|
||||
if not urls:
|
||||
return []
|
||||
content_list = [WORK_HOSTS_SKILL]
|
||||
for url in urls:
|
||||
content_list.append(f'* {url.url} (port {url.port})')
|
||||
content_list.append(WORK_HOSTS_SKILL_FOOTER)
|
||||
content = '\n'.join(content_list)
|
||||
return [Skill(name='work_hosts', content=content, trigger=None)]
|
||||
|
||||
|
||||
class SkillInfo(BaseModel):
|
||||
"""Skill information from agent-server API response."""
|
||||
def load_global_skills() -> list[Skill]:
|
||||
"""Load global skills from OpenHands/skills/ directory.
|
||||
|
||||
name: str
|
||||
content: str
|
||||
triggers: list[str] = []
|
||||
source: str | None = None
|
||||
description: str | None = None
|
||||
is_agentskills_format: bool = False
|
||||
Returns:
|
||||
List of Skill objects loaded from global skills directory.
|
||||
Returns empty list if directory doesn't exist or on errors.
|
||||
"""
|
||||
skill_dir = Path(GLOBAL_SKILLS_DIR)
|
||||
|
||||
# Check if directory exists
|
||||
if not skill_dir.exists():
|
||||
_logger.debug(f'Global skills directory does not exist: {skill_dir}')
|
||||
return []
|
||||
|
||||
try:
|
||||
_logger.info(f'Loading global skills from {skill_dir}')
|
||||
|
||||
# Find and load all .md files from the directory
|
||||
skills = _find_and_load_global_skill_files(skill_dir)
|
||||
|
||||
_logger.info(f'Loaded {len(skills)} global skills: {[s.name for s in skills]}')
|
||||
|
||||
return skills
|
||||
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load global skills: {str(e)}')
|
||||
return []
|
||||
|
||||
|
||||
def _determine_repo_root(working_dir: str, selected_repository: str | None) -> str:
|
||||
"""Determine the repository root directory.
|
||||
|
||||
Args:
|
||||
working_dir: Base working directory path
|
||||
selected_repository: Repository name (e.g., 'owner/repo') or None
|
||||
|
||||
Returns:
|
||||
Path to the repository root directory
|
||||
"""
|
||||
if selected_repository:
|
||||
repo_name = selected_repository.split('/')[-1]
|
||||
return f'{working_dir}/{repo_name}'
|
||||
return working_dir
|
||||
|
||||
|
||||
async def _is_gitlab_repository(repo_name: str, user_context: UserContext) -> bool:
|
||||
@@ -91,6 +154,8 @@ async def _is_gitlab_repository(repo_name: str, user_context: UserContext) -> bo
|
||||
)
|
||||
return repository.git_provider == ProviderType.GITLAB
|
||||
except Exception:
|
||||
# If we can't determine the provider, assume it's not GitLab
|
||||
# This is a safe fallback since we'll just use the default .openhands
|
||||
return False
|
||||
|
||||
|
||||
@@ -113,33 +178,10 @@ async def _is_azure_devops_repository(
|
||||
)
|
||||
return repository.git_provider == ProviderType.AZURE_DEVOPS
|
||||
except Exception:
|
||||
# If we can't determine the provider, assume it's not Azure DevOps
|
||||
return False
|
||||
|
||||
|
||||
async def _get_provider_type(
|
||||
selected_repository: str, user_context: UserContext
|
||||
) -> str:
|
||||
"""Determine the Git provider type for a repository.
|
||||
|
||||
Args:
|
||||
selected_repository: Repository name (e.g., 'owner/repo')
|
||||
user_context: UserContext to access provider handler
|
||||
|
||||
Returns:
|
||||
Provider type string: 'github', 'gitlab', 'azure', or 'bitbucket'
|
||||
"""
|
||||
is_gitlab = await _is_gitlab_repository(selected_repository, user_context)
|
||||
if is_gitlab:
|
||||
return 'gitlab'
|
||||
|
||||
is_azure = await _is_azure_devops_repository(selected_repository, user_context)
|
||||
if is_azure:
|
||||
return 'azure'
|
||||
|
||||
# Default to github (covers github and bitbucket)
|
||||
return 'github'
|
||||
|
||||
|
||||
async def _determine_org_repo_path(
|
||||
selected_repository: str, user_context: UserContext
|
||||
) -> tuple[str, str]:
|
||||
@@ -161,19 +203,27 @@ async def _determine_org_repo_path(
|
||||
"""
|
||||
repo_parts = selected_repository.split('/')
|
||||
|
||||
# Determine repository type
|
||||
is_azure_devops = await _is_azure_devops_repository(
|
||||
selected_repository, user_context
|
||||
)
|
||||
is_gitlab = await _is_gitlab_repository(selected_repository, user_context)
|
||||
|
||||
# Extract the org/user name
|
||||
# Azure DevOps format: org/project/repo (3 parts) - extract org (first part)
|
||||
# GitHub/GitLab/Bitbucket format: owner/repo (2 parts) - extract owner (first part)
|
||||
if is_azure_devops and len(repo_parts) >= 3:
|
||||
org_name = repo_parts[0]
|
||||
org_name = repo_parts[0] # Get org from org/project/repo
|
||||
else:
|
||||
org_name = repo_parts[-2]
|
||||
org_name = repo_parts[-2] # Get owner from owner/repo
|
||||
|
||||
# For GitLab and Azure DevOps, use openhands-config (since .openhands is not a valid repo name)
|
||||
# For other providers, use .openhands
|
||||
if is_gitlab:
|
||||
org_openhands_repo = f'{org_name}/openhands-config'
|
||||
elif is_azure_devops:
|
||||
# Azure DevOps format: org/project/repo
|
||||
# For org-level config, use: org/openhands-config/openhands-config
|
||||
org_openhands_repo = f'{org_name}/openhands-config/openhands-config'
|
||||
else:
|
||||
org_openhands_repo = f'{org_name}/.openhands'
|
||||
@@ -181,6 +231,227 @@ async def _determine_org_repo_path(
|
||||
return org_openhands_repo, org_name
|
||||
|
||||
|
||||
async def _read_file_from_workspace(
|
||||
workspace: AsyncRemoteWorkspace, file_path: str, working_dir: str
|
||||
) -> str | None:
|
||||
"""Read file content from remote workspace.
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
file_path: Path to the file to read
|
||||
working_dir: Working directory for command execution
|
||||
|
||||
Returns:
|
||||
File content as string, or None if file doesn't exist or read fails
|
||||
"""
|
||||
try:
|
||||
result = await workspace.execute_command(
|
||||
f'cat {file_path}', cwd=working_dir, timeout=10.0
|
||||
)
|
||||
if result.exit_code == 0 and result.stdout.strip():
|
||||
return result.stdout
|
||||
return None
|
||||
except Exception as e:
|
||||
_logger.debug(f'Failed to read file {file_path}: {str(e)}')
|
||||
return None
|
||||
|
||||
|
||||
async def _load_special_files(
|
||||
workspace: AsyncRemoteWorkspace, repo_root: str, working_dir: str
|
||||
) -> list[Skill]:
|
||||
"""Load special skill files from repository root.
|
||||
|
||||
Loads: .cursorrules, agents.md, agent.md
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
repo_root: Path to repository root directory
|
||||
working_dir: Working directory for command execution
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from special files
|
||||
"""
|
||||
skills = []
|
||||
special_files = ['.cursorrules', 'agents.md', 'agent.md']
|
||||
|
||||
for filename in special_files:
|
||||
file_path = f'{repo_root}/{filename}'
|
||||
content = await _read_file_from_workspace(workspace, file_path, working_dir)
|
||||
|
||||
if content:
|
||||
try:
|
||||
# Use simple string path to avoid Path filesystem operations
|
||||
skill = Skill.load(path=filename, skill_dir=None, file_content=content)
|
||||
skills.append(skill)
|
||||
_logger.debug(f'Loaded special file skill: {skill.name}')
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to create skill from {filename}: {str(e)}')
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
async def _find_and_load_skill_md_files(
|
||||
workspace: AsyncRemoteWorkspace, skill_dir: str, working_dir: str
|
||||
) -> list[Skill]:
|
||||
"""Find and load all .md files from a skills directory in the workspace.
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
skill_dir: Path to skills directory
|
||||
working_dir: Working directory for command execution
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from the files (excluding README.md)
|
||||
"""
|
||||
skills = []
|
||||
|
||||
try:
|
||||
# Find all .md files in the directory
|
||||
result = await workspace.execute_command(
|
||||
f"find {skill_dir} -type f -name '*.md' 2>/dev/null || true",
|
||||
cwd=working_dir,
|
||||
timeout=10.0,
|
||||
)
|
||||
|
||||
if result.exit_code == 0 and result.stdout.strip():
|
||||
file_paths = [
|
||||
f.strip()
|
||||
for f in result.stdout.strip().split('\n')
|
||||
if f.strip() and 'README.md' not in f
|
||||
]
|
||||
|
||||
# Load skills from the found files
|
||||
for file_path in file_paths:
|
||||
content = await _read_file_from_workspace(
|
||||
workspace, file_path, working_dir
|
||||
)
|
||||
|
||||
if content:
|
||||
# Calculate relative path for skill name
|
||||
rel_path = file_path.replace(f'{skill_dir}/', '')
|
||||
try:
|
||||
# Use simple string path to avoid Path filesystem operations
|
||||
skill = Skill.load(
|
||||
path=rel_path, skill_dir=None, file_content=content
|
||||
)
|
||||
skills.append(skill)
|
||||
_logger.debug(f'Loaded repo skill: {skill.name}')
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
f'Failed to create skill from {rel_path}: {str(e)}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_logger.debug(f'Failed to find skill files in {skill_dir}: {str(e)}')
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
def _merge_repo_skills_with_precedence(
|
||||
special_skills: list[Skill],
|
||||
skills_dir_skills: list[Skill],
|
||||
microagents_dir_skills: list[Skill],
|
||||
) -> list[Skill]:
|
||||
"""Merge repository skills with precedence order.
|
||||
|
||||
Precedence (highest to lowest):
|
||||
1. Special files (repo root)
|
||||
2. .openhands/skills/ directory
|
||||
3. .openhands/microagents/ directory (backward compatibility)
|
||||
|
||||
Args:
|
||||
special_skills: Skills from special files in repo root
|
||||
skills_dir_skills: Skills from .openhands/skills/ directory
|
||||
microagents_dir_skills: Skills from .openhands/microagents/ directory
|
||||
|
||||
Returns:
|
||||
Deduplicated list of skills with proper precedence
|
||||
"""
|
||||
# Use a dict to deduplicate by name, with earlier sources taking precedence
|
||||
skills_by_name = {}
|
||||
for skill in special_skills + skills_dir_skills + microagents_dir_skills:
|
||||
# Only add if not already present (earlier sources win)
|
||||
if skill.name not in skills_by_name:
|
||||
skills_by_name[skill.name] = skill
|
||||
|
||||
return list(skills_by_name.values())
|
||||
|
||||
|
||||
async def load_repo_skills(
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
selected_repository: str | None,
|
||||
working_dir: str,
|
||||
) -> list[Skill]:
|
||||
"""Load repository-level skills from the workspace.
|
||||
|
||||
Loads skills from:
|
||||
1. Special files in repo root: .cursorrules, agents.md, agent.md
|
||||
2. .md files in .openhands/skills/ directory (preferred)
|
||||
3. .md files in .openhands/microagents/ directory (for backward compatibility)
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands in the sandbox
|
||||
selected_repository: Repository name (e.g., 'owner/repo') or None
|
||||
working_dir: Working directory path
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from repository.
|
||||
Returns empty list on errors.
|
||||
"""
|
||||
try:
|
||||
# Determine repository root directory
|
||||
repo_root = _determine_repo_root(working_dir, selected_repository)
|
||||
_logger.info(f'Loading repo skills from {repo_root}')
|
||||
|
||||
# Load special files from repo root
|
||||
special_skills = await _load_special_files(workspace, repo_root, working_dir)
|
||||
|
||||
# Load .md files from .openhands/skills/ directory (preferred)
|
||||
skills_dir = f'{repo_root}/.openhands/skills'
|
||||
skills_dir_skills = await _find_and_load_skill_md_files(
|
||||
workspace, skills_dir, working_dir
|
||||
)
|
||||
|
||||
# Load .md files from .openhands/microagents/ directory (backward compatibility)
|
||||
microagents_dir = f'{repo_root}/.openhands/microagents'
|
||||
microagents_dir_skills = await _find_and_load_skill_md_files(
|
||||
workspace, microagents_dir, working_dir
|
||||
)
|
||||
|
||||
# Merge all loaded skills with proper precedence
|
||||
all_skills = _merge_repo_skills_with_precedence(
|
||||
special_skills, skills_dir_skills, microagents_dir_skills
|
||||
)
|
||||
|
||||
_logger.info(
|
||||
f'Loaded {len(all_skills)} repo skills: {[s.name for s in all_skills]}'
|
||||
)
|
||||
|
||||
return all_skills
|
||||
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load repo skills: {str(e)}')
|
||||
return []
|
||||
|
||||
|
||||
def _validate_repository_for_org_skills(selected_repository: str) -> bool:
|
||||
"""Validate that the repository path has sufficient parts for org skills.
|
||||
|
||||
Args:
|
||||
selected_repository: Repository name (e.g., 'owner/repo')
|
||||
|
||||
Returns:
|
||||
True if repository is valid for org skills loading, False otherwise
|
||||
"""
|
||||
repo_parts = selected_repository.split('/')
|
||||
if len(repo_parts) < 2:
|
||||
_logger.warning(
|
||||
f'Repository path has insufficient parts ({len(repo_parts)} < 2), skipping org-level skills'
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _get_org_repository_url(
|
||||
org_openhands_repo: str, user_context: UserContext
|
||||
) -> str | None:
|
||||
@@ -210,193 +481,224 @@ async def _get_org_repository_url(
|
||||
return None
|
||||
|
||||
|
||||
async def build_org_config(
|
||||
selected_repository: str | None,
|
||||
user_context: UserContext,
|
||||
) -> OrgConfig | None:
|
||||
"""Build organization config for agent-server API request.
|
||||
async def _clone_org_repository(
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
remote_url: str,
|
||||
org_repo_dir: str,
|
||||
working_dir: str,
|
||||
org_openhands_repo: str,
|
||||
) -> bool:
|
||||
"""Clone organization repository to temporary directory.
|
||||
|
||||
Args:
|
||||
selected_repository: Repository name (e.g., 'owner/repo') or None
|
||||
user_context: UserContext to access authentication and provider info
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
remote_url: Authenticated Git URL
|
||||
org_repo_dir: Temporary directory path for cloning
|
||||
working_dir: Working directory for command execution
|
||||
org_openhands_repo: Organization repository path (for logging)
|
||||
|
||||
Returns:
|
||||
org_config dict if org repository exists and is accessible, None otherwise
|
||||
True if clone successful, False otherwise
|
||||
"""
|
||||
_logger.debug(f'Creating temporary directory for org repo: {org_repo_dir}')
|
||||
|
||||
# Clone the repo (shallow clone for efficiency)
|
||||
clone_cmd = f'GIT_TERMINAL_PROMPT=0 git clone --depth 1 {remote_url} {org_repo_dir}'
|
||||
_logger.info('Executing clone command for org-level repo')
|
||||
|
||||
result = await workspace.execute_command(clone_cmd, working_dir, timeout=120.0)
|
||||
|
||||
if result.exit_code != 0:
|
||||
_logger.info(
|
||||
f'No org-level skills found at {org_openhands_repo} (exit_code: {result.exit_code})'
|
||||
)
|
||||
_logger.debug(f'Clone command output: {result.stderr}')
|
||||
return False
|
||||
|
||||
_logger.info(f'Successfully cloned org-level skills from {org_openhands_repo}')
|
||||
return True
|
||||
|
||||
|
||||
async def _load_skills_from_org_directories(
|
||||
workspace: AsyncRemoteWorkspace, org_repo_dir: str, working_dir: str
|
||||
) -> tuple[list[Skill], list[Skill]]:
|
||||
"""Load skills from both skills/ and microagents/ directories in org repo.
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
org_repo_dir: Path to cloned organization repository
|
||||
working_dir: Working directory for command execution
|
||||
|
||||
Returns:
|
||||
Tuple of (skills_dir_skills, microagents_dir_skills)
|
||||
"""
|
||||
skills_dir = f'{org_repo_dir}/skills'
|
||||
skills_dir_skills = await _find_and_load_skill_md_files(
|
||||
workspace, skills_dir, working_dir
|
||||
)
|
||||
|
||||
microagents_dir = f'{org_repo_dir}/microagents'
|
||||
microagents_dir_skills = await _find_and_load_skill_md_files(
|
||||
workspace, microagents_dir, working_dir
|
||||
)
|
||||
|
||||
return skills_dir_skills, microagents_dir_skills
|
||||
|
||||
|
||||
def _merge_org_skills_with_precedence(
|
||||
skills_dir_skills: list[Skill], microagents_dir_skills: list[Skill]
|
||||
) -> list[Skill]:
|
||||
"""Merge skills from skills/ and microagents/ with proper precedence.
|
||||
|
||||
Precedence: skills/ > microagents/ (skills/ overrides microagents/ for same name)
|
||||
|
||||
Args:
|
||||
skills_dir_skills: Skills loaded from skills/ directory
|
||||
microagents_dir_skills: Skills loaded from microagents/ directory
|
||||
|
||||
Returns:
|
||||
Merged list of skills with proper precedence applied
|
||||
"""
|
||||
skills_by_name = {}
|
||||
for skill in microagents_dir_skills + skills_dir_skills:
|
||||
# Later sources (skills/) override earlier ones (microagents/)
|
||||
if skill.name not in skills_by_name:
|
||||
skills_by_name[skill.name] = skill
|
||||
else:
|
||||
_logger.debug(
|
||||
f'Overriding org skill "{skill.name}" from microagents/ with skills/'
|
||||
)
|
||||
skills_by_name[skill.name] = skill
|
||||
|
||||
return list(skills_by_name.values())
|
||||
|
||||
|
||||
async def _cleanup_org_repository(
|
||||
workspace: AsyncRemoteWorkspace, org_repo_dir: str, working_dir: str
|
||||
) -> None:
|
||||
"""Clean up cloned organization repository directory.
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
org_repo_dir: Path to cloned organization repository
|
||||
working_dir: Working directory for command execution
|
||||
"""
|
||||
cleanup_cmd = f'rm -rf {org_repo_dir}'
|
||||
await workspace.execute_command(cleanup_cmd, working_dir, timeout=10.0)
|
||||
|
||||
|
||||
async def load_org_skills(
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
selected_repository: str | None,
|
||||
working_dir: str,
|
||||
user_context: UserContext,
|
||||
) -> list[Skill]:
|
||||
"""Load organization-level skills from the organization repository.
|
||||
|
||||
For example, if the repository is github.com/acme-co/api, this will check if
|
||||
github.com/acme-co/.openhands exists. If it does, it will clone it and load
|
||||
the skills from both the ./skills/ and ./microagents/ folders.
|
||||
|
||||
For GitLab repositories, it will use openhands-config instead of .openhands
|
||||
since GitLab doesn't support repository names starting with non-alphanumeric
|
||||
characters.
|
||||
|
||||
For Azure DevOps repositories, it will use org/openhands-config/openhands-config
|
||||
format to match Azure DevOps's three-part repository structure (org/project/repo).
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands in the sandbox
|
||||
selected_repository: Repository name (e.g., 'owner/repo') or None
|
||||
working_dir: Working directory path
|
||||
user_context: UserContext to access provider handler and authentication
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from organization repository.
|
||||
Returns empty list if no repository selected or on errors.
|
||||
"""
|
||||
if not selected_repository:
|
||||
return None
|
||||
|
||||
repo_parts = selected_repository.split('/')
|
||||
if len(repo_parts) < 2:
|
||||
_logger.warning(
|
||||
f'Repository path has insufficient parts ({len(repo_parts)} < 2), '
|
||||
f'skipping org-level skills'
|
||||
)
|
||||
return None
|
||||
return []
|
||||
|
||||
try:
|
||||
_logger.debug(
|
||||
f'Starting org-level skill loading for repository: {selected_repository}'
|
||||
)
|
||||
|
||||
# Validate repository path
|
||||
if not _validate_repository_for_org_skills(selected_repository):
|
||||
return []
|
||||
|
||||
# Determine organization repository path
|
||||
org_openhands_repo, org_name = await _determine_org_repo_path(
|
||||
selected_repository, user_context
|
||||
)
|
||||
|
||||
org_repo_url = await _get_org_repository_url(org_openhands_repo, user_context)
|
||||
if not org_repo_url:
|
||||
return None
|
||||
_logger.info(f'Checking for org-level skills at {org_openhands_repo}')
|
||||
|
||||
provider = await _get_provider_type(selected_repository, user_context)
|
||||
# Get authenticated URL for org repository
|
||||
remote_url = await _get_org_repository_url(org_openhands_repo, user_context)
|
||||
if not remote_url:
|
||||
return []
|
||||
|
||||
return OrgConfig(
|
||||
repository=selected_repository,
|
||||
provider=provider,
|
||||
org_repo_url=org_repo_url,
|
||||
org_name=org_name,
|
||||
# Clone the organization repository
|
||||
org_repo_dir = f'{working_dir}/_org_openhands_{org_name}'
|
||||
clone_success = await _clone_org_repository(
|
||||
workspace, remote_url, org_repo_dir, working_dir, org_openhands_repo
|
||||
)
|
||||
if not clone_success:
|
||||
return []
|
||||
|
||||
# Load skills from both skills/ and microagents/ directories
|
||||
(
|
||||
skills_dir_skills,
|
||||
microagents_dir_skills,
|
||||
) = await _load_skills_from_org_directories(
|
||||
workspace, org_repo_dir, working_dir
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_logger.debug(f'Failed to build org config: {str(e)}')
|
||||
return None
|
||||
# Merge skills with proper precedence
|
||||
loaded_skills = _merge_org_skills_with_precedence(
|
||||
skills_dir_skills, microagents_dir_skills
|
||||
)
|
||||
|
||||
|
||||
def build_sandbox_config(sandbox: SandboxInfo) -> SandboxConfig | None:
|
||||
"""Build sandbox config for agent-server API request.
|
||||
|
||||
Args:
|
||||
sandbox: SandboxInfo containing exposed URLs
|
||||
|
||||
Returns:
|
||||
sandbox_config dict if there are exposed URLs, None otherwise
|
||||
"""
|
||||
if not sandbox.exposed_urls:
|
||||
return None
|
||||
|
||||
exposed_urls = [
|
||||
ExposedUrlConfig(name=url.name, url=url.url, port=url.port)
|
||||
for url in sandbox.exposed_urls
|
||||
]
|
||||
|
||||
return SandboxConfig(exposed_urls=exposed_urls)
|
||||
|
||||
|
||||
async def load_skills_from_agent_server(
|
||||
agent_server_url: str,
|
||||
session_api_key: str | None,
|
||||
project_dir: str,
|
||||
org_config: OrgConfig | None = None,
|
||||
sandbox_config: SandboxConfig | None = None,
|
||||
load_public: bool = True,
|
||||
load_user: bool = True,
|
||||
load_project: bool = True,
|
||||
load_org: bool = True,
|
||||
) -> list[Skill]:
|
||||
"""Load all skills from the agent-server.
|
||||
|
||||
This function makes a single API call to the agent-server's /api/skills
|
||||
endpoint to load and merge skills from all configured sources.
|
||||
|
||||
Args:
|
||||
agent_server_url: URL of the agent server (e.g., 'http://localhost:8000')
|
||||
session_api_key: Session API key for authentication (optional)
|
||||
project_dir: Workspace directory path for project skills
|
||||
org_config: Organization skills configuration (optional)
|
||||
sandbox_config: Sandbox skills configuration (optional)
|
||||
load_public: Whether to load public skills (default: True)
|
||||
load_user: Whether to load user skills (default: True)
|
||||
load_project: Whether to load project skills (default: True)
|
||||
load_org: Whether to load organization skills (default: True)
|
||||
|
||||
Returns:
|
||||
List of Skill objects merged from all sources.
|
||||
Returns empty list on error.
|
||||
"""
|
||||
try:
|
||||
# Build request payload
|
||||
payload = {
|
||||
'load_public': load_public,
|
||||
'load_user': load_user,
|
||||
'load_project': load_project,
|
||||
'load_org': load_org,
|
||||
'project_dir': project_dir,
|
||||
'org_config': org_config.model_dump() if org_config else None,
|
||||
'sandbox_config': sandbox_config.model_dump() if sandbox_config else None,
|
||||
}
|
||||
|
||||
# Build headers
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
if session_api_key:
|
||||
headers['X-Session-API-Key'] = session_api_key
|
||||
|
||||
# Make API request
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f'{agent_server_url}/api/skills',
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=60.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Convert response to Skill objects
|
||||
skills: list[Skill] = []
|
||||
for skill_data_dict in data.get('skills', []):
|
||||
try:
|
||||
skill_info = SkillInfo.model_validate(skill_data_dict)
|
||||
skill = _convert_skill_info_to_skill(skill_info)
|
||||
skills.append(skill)
|
||||
except Exception as e:
|
||||
skill_name = (
|
||||
skill_data_dict.get('name', 'unknown')
|
||||
if isinstance(skill_data_dict, dict)
|
||||
else 'unknown'
|
||||
)
|
||||
_logger.warning(f'Failed to convert skill {skill_name}: {e}')
|
||||
|
||||
sources = data.get('sources', {})
|
||||
_logger.info(
|
||||
f'Loaded {len(skills)} skills from agent-server: '
|
||||
f'sources={sources}, names={[s.name for s in skills]}'
|
||||
f'Loaded {len(loaded_skills)} skills from org-level repository {org_openhands_repo}: {[s.name for s in loaded_skills]}'
|
||||
)
|
||||
|
||||
return skills
|
||||
# Clean up the org repo directory
|
||||
await _cleanup_org_repository(workspace, org_repo_dir, working_dir)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
_logger.warning(
|
||||
f'Agent-server returned error status {e.response.status_code}: '
|
||||
f'{e.response.text}'
|
||||
)
|
||||
return []
|
||||
except httpx.RequestError as e:
|
||||
_logger.warning(f'Failed to connect to agent-server: {e}')
|
||||
return loaded_skills
|
||||
|
||||
except AuthenticationError as e:
|
||||
_logger.debug(f'org-level skill directory not found: {str(e)}')
|
||||
return []
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load skills from agent-server: {e}')
|
||||
_logger.warning(f'Failed to load org-level skills: {str(e)}')
|
||||
return []
|
||||
|
||||
|
||||
def _convert_skill_info_to_skill(skill_info: SkillInfo) -> Skill:
|
||||
"""Convert skill info from API response to Skill object.
|
||||
def merge_skills(skill_lists: list[list[Skill]]) -> list[Skill]:
|
||||
"""Merge multiple skill lists, avoiding duplicates by name.
|
||||
|
||||
Later lists take precedence over earlier lists for duplicate names.
|
||||
|
||||
Args:
|
||||
skill_info: SkillInfo model from API response
|
||||
skill_lists: List of skill lists to merge
|
||||
|
||||
Returns:
|
||||
Skill object
|
||||
Deduplicated list of skills with later lists overriding earlier ones
|
||||
"""
|
||||
trigger = None
|
||||
skills_by_name = {}
|
||||
|
||||
if skill_info.triggers:
|
||||
# Determine trigger type based on content
|
||||
if any(t.startswith('/') for t in skill_info.triggers):
|
||||
trigger = TaskTrigger(triggers=skill_info.triggers)
|
||||
else:
|
||||
trigger = KeywordTrigger(keywords=skill_info.triggers)
|
||||
for skill_list in skill_lists:
|
||||
for skill in skill_list:
|
||||
if skill.name in skills_by_name:
|
||||
_logger.debug(
|
||||
f'Overriding skill "{skill.name}" from earlier source with later source'
|
||||
)
|
||||
skills_by_name[skill.name] = skill
|
||||
|
||||
return Skill(
|
||||
name=skill_info.name,
|
||||
content=skill_info.content,
|
||||
trigger=trigger,
|
||||
source=skill_info.source,
|
||||
description=skill_info.description,
|
||||
is_agentskills_format=skill_info.is_agentskills_format,
|
||||
)
|
||||
result = list(skills_by_name.values())
|
||||
_logger.debug(f'Merged skills: {[s.name for s in result]}')
|
||||
return result
|
||||
|
||||
@@ -541,8 +541,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
|
||||
def _fix_timezone(self, value: datetime) -> datetime:
|
||||
"""Sqlite does not stpre timezones - and since we can't update the existing models
|
||||
we assume UTC if the timezone is missing.
|
||||
"""
|
||||
we assume UTC if the timezone is missing."""
|
||||
if not value.tzinfo:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
return value
|
||||
|
||||
@@ -68,8 +68,7 @@ class StoredAppConversationStartTask(Base): # type: ignore
|
||||
class SQLAppConversationStartTaskService(AppConversationStartTaskService):
|
||||
"""SQL implementation of AppConversationStartTaskService focused on db operations.
|
||||
|
||||
This allows storing and retrieving conversation start tasks from the database.
|
||||
"""
|
||||
This allows storing and retrieving conversation start tasks from the database."""
|
||||
|
||||
session: AsyncSession
|
||||
user_id: str | None = None
|
||||
|
||||
@@ -243,7 +243,20 @@ def config_from_env() -> AppServerConfig:
|
||||
config.sandbox_spec = DockerSandboxSpecServiceInjector()
|
||||
|
||||
if config.app_conversation_info is None:
|
||||
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
|
||||
# Use enterprise injector if running in SAAS mode
|
||||
if 'saas' in (os.getenv('OPENHANDS_CONFIG_CLS') or '').lower():
|
||||
try:
|
||||
# Import enterprise injector dynamically
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasAppConversationInfoServiceInjector,
|
||||
)
|
||||
|
||||
config.app_conversation_info = SaasAppConversationInfoServiceInjector()
|
||||
except ImportError:
|
||||
# Fallback to OSS injector if enterprise module is not available
|
||||
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
|
||||
else:
|
||||
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
|
||||
|
||||
if config.app_conversation_start_task is None:
|
||||
config.app_conversation_start_task = (
|
||||
|
||||
@@ -13,7 +13,7 @@ from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
# The version of the agent server to use for deployments.
|
||||
# Typically this will be the same as the values from the pyproject.toml
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:c775ff6-python'
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:10fff69-python'
|
||||
|
||||
|
||||
class SandboxSpecService(ABC):
|
||||
|
||||
+1
-16
@@ -136,7 +136,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
if self.config.model.startswith('openhands/'):
|
||||
model_name = self.config.model.removeprefix('openhands/')
|
||||
self.config.model = f'litellm_proxy/{model_name}'
|
||||
self.config.base_url = _get_openhands_llm_base_url()
|
||||
self.config.base_url = 'https://llm-proxy.app.all-hands.dev/'
|
||||
logger.debug(
|
||||
f'Rewrote openhands/{model_name} to {self.config.model} with base URL {self.config.base_url}'
|
||||
)
|
||||
@@ -851,18 +851,3 @@ class LLM(RetryMixin, DebugMixin):
|
||||
|
||||
# let pydantic handle the serialization
|
||||
return [message.model_dump() for message in messages]
|
||||
|
||||
|
||||
def _get_openhands_llm_base_url():
|
||||
# Get the API url if specified
|
||||
lite_llm_api_url = os.getenv('LITE_LLM_API_URL')
|
||||
if lite_llm_api_url:
|
||||
return lite_llm_api_url
|
||||
|
||||
# Fallback to using web_host.
|
||||
web_host = os.getenv('WEB_HOST')
|
||||
if web_host and ('.staging.' in web_host or web_host.startswith('staging')):
|
||||
return 'https://llm-proxy.staging.all-hands.dev/'
|
||||
|
||||
# Use the default
|
||||
return 'https://llm-proxy.app.all-hands.dev/'
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@@ -33,9 +31,7 @@ class Settings(BaseModel):
|
||||
user_version: int | None = None
|
||||
remote_runtime_resource_factor: int | None = None
|
||||
# Planned to be removed from settings
|
||||
secrets_store: Annotated[Secrets, Field(frozen=True)] = Field(
|
||||
default_factory=Secrets
|
||||
)
|
||||
secrets_store: Secrets = Field(default_factory=Secrets, frozen=True)
|
||||
enable_default_condenser: bool = True
|
||||
enable_sound_notifications: bool = False
|
||||
enable_proactive_conversation_starters: bool = True
|
||||
|
||||
Generated
+10
-10
@@ -7731,14 +7731,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.10.0"
|
||||
version = "1.8.2"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
|
||||
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
|
||||
{file = "openhands_agent_server-1.8.2-py3-none-any.whl", hash = "sha256:e9abb2e0fe970715537d0e0fc1aea3dd64bb9e8b531f70cb72b3d4e486aaa46a"},
|
||||
{file = "openhands_agent_server-1.8.2.tar.gz", hash = "sha256:43db2371ee84b100ac921396338dee74359fceeb5c9400c90530bcc5730144c3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -7755,14 +7755,14 @@ wsproto = ">=1.2.0"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.10.0"
|
||||
version = "1.8.2"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
|
||||
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
|
||||
{file = "openhands_sdk-1.8.2-py3-none-any.whl", hash = "sha256:b4fad9581865ce222a3e6722384e4df56113db01bd34c2d2d408dfd9695365c0"},
|
||||
{file = "openhands_sdk-1.8.2.tar.gz", hash = "sha256:5bfb17c8b9515210d121249deb1f3d0dc407c3737edc55b5e73330b4571d61e3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -7783,14 +7783,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.10.0"
|
||||
version = "1.8.2"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
|
||||
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
|
||||
{file = "openhands_tools-1.8.2-py3-none-any.whl", hash = "sha256:283f0c1fdd316914559cd16ade792383715478a8f5a73f7166daffc34bf9e5af"},
|
||||
{file = "openhands_tools-1.8.2.tar.gz", hash = "sha256:eae416e3867f7cb595129a33a4b9237886c4b8a075d2bc7618da55963f2747d5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -17367,4 +17367,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "f67478db2385eb258369313ac831b26582d744294c0996a35e786c3d7ced5db1"
|
||||
content-hash = "530cf5f60f1a38d69fea854eb3682a64a192f9f97beaa0dfc9dcedf239de3a58"
|
||||
|
||||
+9
-6
@@ -54,9 +54,9 @@ dependencies = [
|
||||
"numpy",
|
||||
"openai==2.8",
|
||||
"openhands-aci==0.3.2",
|
||||
"openhands-agent-server==1.10",
|
||||
"openhands-sdk==1.10",
|
||||
"openhands-tools==1.10",
|
||||
"openhands-agent-server==1.8.2",
|
||||
"openhands-sdk==1.8.2",
|
||||
"openhands-tools==1.8.2",
|
||||
"opentelemetry-api>=1.33.1",
|
||||
"opentelemetry-exporter-otlp-proto-grpc>=1.33.1",
|
||||
"pathspec>=0.12.1",
|
||||
@@ -280,9 +280,12 @@ e2b-code-interpreter = { version = "^2.0.0", optional = true }
|
||||
pybase62 = "^1.0.0"
|
||||
|
||||
# V1 dependencies
|
||||
openhands-sdk = "1.10"
|
||||
openhands-agent-server = "1.10"
|
||||
openhands-tools = "1.10"
|
||||
#openhands-agent-server = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-agent-server", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
#openhands-sdk = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-sdk", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
#openhands-tools = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-tools", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
openhands-sdk = "1.8.2"
|
||||
openhands-agent-server = "1.8.2"
|
||||
openhands-tools = "1.8.2"
|
||||
python-jose = { version = ">=3.3", extras = [ "cryptography" ] }
|
||||
sqlalchemy = { extras = [ "asyncio" ], version = "^2.0.40" }
|
||||
pg8000 = "^1.31.5"
|
||||
|
||||
@@ -17,7 +17,6 @@ from openhands.app_server.app_conversation.app_conversation_service_base import
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.sdk.context.skills import Skill
|
||||
|
||||
|
||||
class MockUserInfo:
|
||||
@@ -921,251 +920,347 @@ async def test_configure_git_user_settings_special_characters_in_name(mock_works
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for load_and_merge_all_skills (updated to use agent-server)
|
||||
# Tests for load_and_merge_all_skills with org skills
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMergeSkills:
|
||||
"""Test _merge_skills method."""
|
||||
class TestLoadAndMergeAllSkillsWithOrgSkills:
|
||||
"""Test load_and_merge_all_skills includes organization skills."""
|
||||
|
||||
def test_merges_skills_with_no_duplicates(self):
|
||||
"""Test merging skill lists with no duplicate names."""
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_includes_org_skills(
|
||||
self,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test that load_and_merge_all_skills loads and merges org skills."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
skill1 = Mock(spec=Skill)
|
||||
skill1.name = 'skill1'
|
||||
skill2 = Mock(spec=Skill)
|
||||
skill2.name = 'skill2'
|
||||
skill3 = Mock(spec=Skill)
|
||||
skill3.name = 'skill3'
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
skill_lists = [[skill1], [skill2], [skill3]]
|
||||
# Create distinct mock skills for each source
|
||||
sandbox_skill = Mock()
|
||||
sandbox_skill.name = 'sandbox_skill'
|
||||
global_skill = Mock()
|
||||
global_skill.name = 'global_skill'
|
||||
user_skill = Mock()
|
||||
user_skill.name = 'user_skill'
|
||||
org_skill = Mock()
|
||||
org_skill.name = 'org_skill'
|
||||
repo_skill = Mock()
|
||||
repo_skill.name = 'repo_skill'
|
||||
|
||||
mock_load_sandbox.return_value = [sandbox_skill]
|
||||
mock_load_global.return_value = [global_skill]
|
||||
mock_load_user.return_value = [user_skill]
|
||||
mock_load_org.return_value = [org_skill]
|
||||
mock_load_repo.return_value = [repo_skill]
|
||||
|
||||
# Act
|
||||
result = service._merge_skills(skill_lists)
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, remote_workspace, 'owner/repo', '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert len(result) == 5
|
||||
names = {s.name for s in result}
|
||||
assert names == {'skill1', 'skill2', 'skill3'}
|
||||
|
||||
def test_merges_skills_with_duplicates_later_wins(self):
|
||||
"""Test that later skill lists override earlier ones for duplicate names."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
assert names == {
|
||||
'sandbox_skill',
|
||||
'global_skill',
|
||||
'user_skill',
|
||||
'org_skill',
|
||||
'repo_skill',
|
||||
}
|
||||
mock_load_org.assert_called_once_with(
|
||||
remote_workspace, 'owner/repo', '/workspace', mock_user_context
|
||||
)
|
||||
|
||||
skill1_v1 = Mock(spec=Skill)
|
||||
skill1_v1.name = 'skill1'
|
||||
skill1_v1.version = 'v1'
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_org_skills_precedence(
|
||||
self,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test that org skills have correct precedence (higher than user, lower than repo)."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
skill1_v2 = Mock(spec=Skill)
|
||||
skill1_v2.name = 'skill1'
|
||||
skill1_v2.version = 'v2'
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
skill2 = Mock(spec=Skill)
|
||||
skill2.name = 'skill2'
|
||||
# Create skills with same name but different sources
|
||||
user_skill = Mock()
|
||||
user_skill.name = 'common_skill'
|
||||
user_skill.source = 'user'
|
||||
|
||||
skill_lists = [[skill1_v1], [skill1_v2, skill2]]
|
||||
org_skill = Mock()
|
||||
org_skill.name = 'common_skill'
|
||||
org_skill.source = 'org'
|
||||
|
||||
repo_skill = Mock()
|
||||
repo_skill.name = 'common_skill'
|
||||
repo_skill.source = 'repo'
|
||||
|
||||
mock_load_sandbox.return_value = []
|
||||
mock_load_global.return_value = []
|
||||
mock_load_user.return_value = [user_skill]
|
||||
mock_load_org.return_value = [org_skill]
|
||||
mock_load_repo.return_value = [repo_skill]
|
||||
|
||||
# Act
|
||||
result = service._merge_skills(skill_lists)
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, remote_workspace, 'owner/repo', '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should have only one skill with repo source (highest precedence)
|
||||
assert len(result) == 1
|
||||
assert result[0].source == 'repo'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_org_skills_override_user_skills(
|
||||
self,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test that org skills override user skills for same name."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
# Create skills with same name
|
||||
user_skill = Mock()
|
||||
user_skill.name = 'shared_skill'
|
||||
user_skill.priority = 'low'
|
||||
|
||||
org_skill = Mock()
|
||||
org_skill.name = 'shared_skill'
|
||||
org_skill.priority = 'high'
|
||||
|
||||
mock_load_sandbox.return_value = []
|
||||
mock_load_global.return_value = []
|
||||
mock_load_user.return_value = [user_skill]
|
||||
mock_load_org.return_value = [org_skill]
|
||||
mock_load_repo.return_value = []
|
||||
|
||||
# Act
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, remote_workspace, 'owner/repo', '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].priority == 'high' # Org skill should win
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_handles_org_skills_failure(
|
||||
self,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test that failure to load org skills doesn't break the overall process."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
global_skill = Mock()
|
||||
global_skill.name = 'global_skill'
|
||||
repo_skill = Mock()
|
||||
repo_skill.name = 'repo_skill'
|
||||
|
||||
mock_load_sandbox.return_value = []
|
||||
mock_load_global.return_value = [global_skill]
|
||||
mock_load_user.return_value = []
|
||||
mock_load_org.return_value = [] # Org skills failed/empty
|
||||
mock_load_repo.return_value = [repo_skill]
|
||||
|
||||
# Act
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, remote_workspace, 'owner/repo', '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should still have skills from other sources
|
||||
assert len(result) == 2
|
||||
skill1_result = next(s for s in result if s.name == 'skill1')
|
||||
assert skill1_result.version == 'v2'
|
||||
|
||||
|
||||
class TestLoadAndMergeAllSkills:
|
||||
"""Test load_and_merge_all_skills method (updated to use agent-server)."""
|
||||
names = {s.name for s in result}
|
||||
assert names == {'global_skill', 'repo_skill'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
async def test_loads_skills_successfully(
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_no_selected_repository(
|
||||
self,
|
||||
mock_build_sandbox_config,
|
||||
mock_build_org_config,
|
||||
mock_load_skills,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test successfully loading skills from agent-server."""
|
||||
"""Test skill loading when no repository is selected."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
mock_workspace = AsyncMock()
|
||||
mock_workspace.working_dir = '/workspace'
|
||||
|
||||
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
exposed_url = ExposedUrl(
|
||||
name='AGENT_SERVER', url='http://localhost:8000', port=8000
|
||||
)
|
||||
sandbox.exposed_urls = [exposed_url]
|
||||
sandbox.session_api_key = 'test-api-key'
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
skill1 = Mock(spec=Skill)
|
||||
skill1.name = 'skill1'
|
||||
skill2 = Mock(spec=Skill)
|
||||
skill2.name = 'skill2'
|
||||
global_skill = Mock()
|
||||
global_skill.name = 'global_skill'
|
||||
|
||||
mock_load_skills.return_value = [skill1, skill2]
|
||||
mock_build_org_config.return_value = {'repository': 'owner/repo'}
|
||||
mock_build_sandbox_config.return_value = {'exposed_urls': []}
|
||||
mock_load_sandbox.return_value = []
|
||||
mock_load_global.return_value = [global_skill]
|
||||
mock_load_user.return_value = []
|
||||
mock_load_org.return_value = []
|
||||
mock_load_repo.return_value = []
|
||||
|
||||
# Act
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, 'owner/repo', '/workspace', 'http://localhost:8000'
|
||||
sandbox, remote_workspace, None, '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].name == 'skill1'
|
||||
assert result[1].name == 'skill2'
|
||||
mock_load_skills.assert_called_once()
|
||||
call_kwargs = mock_load_skills.call_args[1]
|
||||
assert call_kwargs['agent_server_url'] == 'http://localhost:8000'
|
||||
assert call_kwargs['session_api_key'] == 'test-api-key'
|
||||
assert call_kwargs['project_dir'] == '/workspace/repo'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
|
||||
)
|
||||
async def test_returns_empty_list_when_no_agent_server_url(self, mock_load_skills):
|
||||
"""Test returns empty list when agent-server URL is not available."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
assert len(result) == 1
|
||||
# Org skills should be called even with None repository
|
||||
mock_load_org.assert_called_once_with(
|
||||
remote_workspace, None, '/workspace', mock_user_context
|
||||
)
|
||||
|
||||
AsyncMock()
|
||||
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
exposed_url = ExposedUrl(
|
||||
name='VSCODE', url='http://localhost:8080', port=8080
|
||||
)
|
||||
sandbox.exposed_urls = [exposed_url]
|
||||
|
||||
# Act - pass empty string to simulate no agent server URL
|
||||
# This should still call load_skills_from_agent_server but it will fail
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, 'owner/repo', '/workspace', ''
|
||||
)
|
||||
|
||||
# Assert - should return empty list when agent_server_url is empty
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
|
||||
)
|
||||
async def test_uses_working_dir_when_no_repository(
|
||||
self,
|
||||
mock_build_sandbox_config,
|
||||
mock_build_org_config,
|
||||
mock_load_skills,
|
||||
):
|
||||
"""Test uses working_dir as project_dir when no repository is selected."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
)
|
||||
|
||||
AsyncMock()
|
||||
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
exposed_url = ExposedUrl(
|
||||
name='AGENT_SERVER', url='http://localhost:8000', port=8000
|
||||
)
|
||||
sandbox.exposed_urls = [exposed_url]
|
||||
sandbox.session_api_key = 'test-key'
|
||||
|
||||
mock_load_skills.return_value = []
|
||||
mock_build_org_config.return_value = None
|
||||
mock_build_sandbox_config.return_value = None
|
||||
|
||||
# Act
|
||||
await service.load_and_merge_all_skills(
|
||||
sandbox, None, '/workspace', 'http://localhost:8000'
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_load_skills.call_args[1]
|
||||
assert call_kwargs['project_dir'] == '/workspace'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
|
||||
)
|
||||
async def test_handles_exception_gracefully(
|
||||
self,
|
||||
mock_build_sandbox_config,
|
||||
mock_build_org_config,
|
||||
mock_load_skills,
|
||||
):
|
||||
"""Test handles exceptions during skill loading."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
)
|
||||
|
||||
AsyncMock()
|
||||
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
exposed_url = ExposedUrl(
|
||||
name='AGENT_SERVER', url='http://localhost:8000', port=8000
|
||||
)
|
||||
sandbox.exposed_urls = [exposed_url]
|
||||
sandbox.session_api_key = 'test-key'
|
||||
|
||||
mock_load_skills.side_effect = Exception('Network error')
|
||||
|
||||
# Act
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, 'owner/repo', '/workspace', 'http://localhost:8000'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
@@ -1165,50 +1165,6 @@ class TestLiveStatusAppConversationService:
|
||||
)
|
||||
self.mock_event_service.search_events.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_conversation_calls_search_events_with_correct_parameter_name(
|
||||
self,
|
||||
):
|
||||
"""Test that export_conversation calls search_events with 'conversation_id' parameter, not 'conversation_id__eq'.
|
||||
|
||||
This test verifies the fix for a bug where page_iterator was called with
|
||||
conversation_id__eq instead of conversation_id, causing a TypeError since
|
||||
the search_events method expects conversation_id as its parameter name.
|
||||
"""
|
||||
# Arrange
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock conversation info
|
||||
mock_conversation_info = Mock(spec=AppConversationInfo)
|
||||
mock_conversation_info.id = conversation_id
|
||||
mock_conversation_info.model_dump_json = Mock(return_value='{}')
|
||||
|
||||
self.mock_app_conversation_info_service.get_app_conversation_info = AsyncMock(
|
||||
return_value=mock_conversation_info
|
||||
)
|
||||
|
||||
# Mock empty event page to simplify test
|
||||
mock_event_page = Mock()
|
||||
mock_event_page.items = []
|
||||
mock_event_page.next_page_id = None
|
||||
|
||||
self.mock_event_service.search_events = AsyncMock(return_value=mock_event_page)
|
||||
|
||||
# Act
|
||||
await self.service.export_conversation(conversation_id)
|
||||
|
||||
# Assert - Verify search_events was called with 'conversation_id', not 'conversation_id__eq'
|
||||
self.mock_event_service.search_events.assert_called()
|
||||
call_kwargs = self.mock_event_service.search_events.call_args[1]
|
||||
|
||||
assert 'conversation_id' in call_kwargs, (
|
||||
"search_events should be called with 'conversation_id' parameter"
|
||||
)
|
||||
assert 'conversation_id__eq' not in call_kwargs, (
|
||||
"search_events should NOT be called with 'conversation_id__eq' parameter"
|
||||
)
|
||||
assert call_kwargs['conversation_id'] == conversation_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_conversation_large_pagination(self):
|
||||
"""Test download with multiple pages of events."""
|
||||
@@ -1332,7 +1288,7 @@ class TestLiveStatusAppConversationService:
|
||||
task.sandbox_id = self.mock_sandbox.id
|
||||
yield task
|
||||
|
||||
async def mock_run_setup_scripts(task, sandbox, workspace, agent_server_url):
|
||||
async def mock_run_setup_scripts(task, sandbox, workspace):
|
||||
yield task
|
||||
|
||||
self.service._wait_for_sandbox_start = mock_wait_for_sandbox
|
||||
@@ -1786,855 +1742,3 @@ class TestLiveStatusAppConversationService:
|
||||
stdio_server = mcp_servers['stdio-server']
|
||||
assert stdio_server['command'] == 'npx'
|
||||
assert stdio_server['env'] == {'TOKEN': 'value'}
|
||||
|
||||
|
||||
class TestPluginHandling:
|
||||
"""Test cases for plugin-related functionality in LiveStatusAppConversationService."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
# Create mock dependencies
|
||||
self.mock_user_context = Mock(spec=UserContext)
|
||||
self.mock_user_auth = Mock()
|
||||
self.mock_user_context.user_auth = self.mock_user_auth
|
||||
self.mock_jwt_service = Mock()
|
||||
self.mock_sandbox_service = Mock()
|
||||
self.mock_sandbox_spec_service = Mock()
|
||||
self.mock_app_conversation_info_service = Mock()
|
||||
self.mock_app_conversation_start_task_service = Mock()
|
||||
self.mock_event_callback_service = Mock()
|
||||
self.mock_event_service = Mock()
|
||||
self.mock_httpx_client = Mock()
|
||||
|
||||
# Create service instance
|
||||
self.service = LiveStatusAppConversationService(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=self.mock_user_context,
|
||||
app_conversation_info_service=self.mock_app_conversation_info_service,
|
||||
app_conversation_start_task_service=self.mock_app_conversation_start_task_service,
|
||||
event_callback_service=self.mock_event_callback_service,
|
||||
event_service=self.mock_event_service,
|
||||
sandbox_service=self.mock_sandbox_service,
|
||||
sandbox_spec_service=self.mock_sandbox_spec_service,
|
||||
jwt_service=self.mock_jwt_service,
|
||||
sandbox_startup_timeout=30,
|
||||
sandbox_startup_poll_frequency=1,
|
||||
httpx_client=self.mock_httpx_client,
|
||||
web_url='https://test.example.com',
|
||||
openhands_provider_base_url='https://provider.example.com',
|
||||
access_token_hard_timeout=None,
|
||||
app_mode='test',
|
||||
)
|
||||
|
||||
# Mock user info
|
||||
self.mock_user = Mock()
|
||||
self.mock_user.id = 'test_user_123'
|
||||
self.mock_user.llm_model = 'gpt-4'
|
||||
self.mock_user.llm_base_url = 'https://api.openai.com/v1'
|
||||
self.mock_user.llm_api_key = 'test_api_key'
|
||||
self.mock_user.confirmation_mode = False
|
||||
self.mock_user.search_api_key = None
|
||||
self.mock_user.condenser_max_size = None
|
||||
self.mock_user.mcp_config = None
|
||||
self.mock_user.security_analyzer = None
|
||||
|
||||
# Mock sandbox
|
||||
self.mock_sandbox = Mock(spec=SandboxInfo)
|
||||
self.mock_sandbox.id = uuid4()
|
||||
self.mock_sandbox.status = SandboxStatus.RUNNING
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_no_plugins(self):
|
||||
"""Test _construct_initial_message_with_plugin_params with no plugins returns original message."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
|
||||
# Test with None initial message and None plugins
|
||||
result = self.service._construct_initial_message_with_plugin_params(None, None)
|
||||
assert result is None
|
||||
|
||||
# Test with None initial message and empty plugins list
|
||||
result = self.service._construct_initial_message_with_plugin_params(None, [])
|
||||
assert result is None
|
||||
|
||||
# Test with initial message but None plugins
|
||||
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, None
|
||||
)
|
||||
assert result is initial_msg
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_no_params(self):
|
||||
"""Test _construct_initial_message_with_plugin_params with plugins but no parameters."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Plugin with no parameters
|
||||
plugins = [PluginSpec(source='github:owner/repo')]
|
||||
|
||||
# Test with None initial message
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
None, plugins
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Test with initial message
|
||||
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
assert result is initial_msg
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_creates_new_message(self):
|
||||
"""Test _construct_initial_message_with_plugin_params creates message when no initial message."""
|
||||
from openhands.agent_server.models import TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/repo',
|
||||
parameters={'api_key': 'test123', 'debug': True},
|
||||
)
|
||||
]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
None, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Plugin Configuration Parameters:' in result.content[0].text
|
||||
assert '- api_key: test123' in result.content[0].text
|
||||
assert '- debug: True' in result.content[0].text
|
||||
assert result.run is True
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_appends_to_message(self):
|
||||
"""Test _construct_initial_message_with_plugin_params appends to existing message."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
initial_msg = SendMessageRequest(
|
||||
content=[TextContent(text='Please analyze this codebase')],
|
||||
run=False,
|
||||
)
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/repo',
|
||||
ref='v1.0.0',
|
||||
parameters={'target_dir': '/src', 'verbose': True},
|
||||
)
|
||||
]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
text = result.content[0].text
|
||||
assert text.startswith('Please analyze this codebase')
|
||||
assert 'Plugin Configuration Parameters:' in text
|
||||
assert '- target_dir: /src' in text
|
||||
assert '- verbose: True' in text
|
||||
assert result.run is False
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_preserves_role(self):
|
||||
"""Test _construct_initial_message_with_plugin_params preserves message role."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
initial_msg = SendMessageRequest(
|
||||
role='system',
|
||||
content=[TextContent(text='System message')],
|
||||
)
|
||||
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.role == 'system'
|
||||
|
||||
def test_construct_initial_message_with_plugin_params_empty_content(self):
|
||||
"""Test _construct_initial_message_with_plugin_params handles empty content list."""
|
||||
from openhands.agent_server.models import SendMessageRequest, TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
initial_msg = SendMessageRequest(content=[])
|
||||
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
initial_msg, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert 'Plugin Configuration Parameters:' in result.content[0].text
|
||||
|
||||
def test_construct_initial_message_with_multiple_plugins(self):
|
||||
"""Test _construct_initial_message_with_plugin_params handles multiple plugins."""
|
||||
from openhands.agent_server.models import TextContent
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/plugin1',
|
||||
parameters={'key1': 'value1'},
|
||||
),
|
||||
PluginSpec(
|
||||
source='github:owner/plugin2',
|
||||
parameters={'key2': 'value2'},
|
||||
),
|
||||
]
|
||||
|
||||
result = self.service._construct_initial_message_with_plugin_params(
|
||||
None, plugins
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
text = result.content[0].text
|
||||
assert 'Plugin Configuration Parameters:' in text
|
||||
# Multiple plugins should show grouped by plugin name
|
||||
assert 'plugin1' in text
|
||||
assert 'plugin2' in text
|
||||
assert 'key1: value1' in text
|
||||
assert 'key2: value2' in text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_with_plugins(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request passes plugins list to StartConversationRequest."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {'test': StaticSecret(value='secret')}
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/my-plugin',
|
||||
ref='v1.0.0',
|
||||
parameters={'api_key': 'test123'},
|
||||
)
|
||||
]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 1
|
||||
assert result.plugins[0].source == 'github:owner/my-plugin'
|
||||
assert result.plugins[0].ref == 'v1.0.0'
|
||||
# Also verify initial message contains plugin params
|
||||
assert result.initial_message is not None
|
||||
assert (
|
||||
'Plugin Configuration Parameters:' in result.initial_message.content[0].text
|
||||
)
|
||||
assert '- api_key: test123' in result.initial_message.content[0].text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_without_plugins(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request without plugins sets plugins to None."""
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_plugin_without_ref(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request with plugin that has no ref."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Plugin without ref or parameters
|
||||
plugins = [PluginSpec(source='github:owner/my-plugin')]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 1
|
||||
assert result.plugins[0].source == 'github:owner/my-plugin'
|
||||
assert result.plugins[0].ref is None
|
||||
# No parameters, so initial message should be None
|
||||
assert result.initial_message is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_plugin_with_repo_path(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request passes repo_path to PluginSource."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Plugin with repo_path (for marketplace repos containing multiple plugins)
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/marketplace-repo',
|
||||
ref='main',
|
||||
repo_path='plugins/city-weather',
|
||||
)
|
||||
]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 1
|
||||
assert result.plugins[0].source == 'github:owner/marketplace-repo'
|
||||
assert result.plugins[0].ref == 'main'
|
||||
assert result.plugins[0].repo_path == 'plugins/city-weather'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
|
||||
)
|
||||
async def test_finalize_conversation_request_multiple_plugins(
|
||||
self, mock_experiment_manager
|
||||
):
|
||||
"""Test _finalize_conversation_request with multiple plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
mock_agent = Mock(spec=Agent)
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.model = 'gpt-4'
|
||||
mock_llm.usage_id = 'agent'
|
||||
|
||||
mock_updated_agent = Mock(spec=Agent)
|
||||
mock_updated_agent.llm = mock_llm
|
||||
mock_updated_agent.condenser = None
|
||||
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
|
||||
mock_updated_agent
|
||||
)
|
||||
|
||||
workspace = LocalWorkspace(working_dir='/test')
|
||||
secrets = {}
|
||||
|
||||
# Multiple plugins
|
||||
plugins = [
|
||||
PluginSpec(source='github:owner/security-plugin', ref='v2.0.0'),
|
||||
PluginSpec(
|
||||
source='github:owner/monorepo',
|
||||
repo_path='plugins/logging',
|
||||
),
|
||||
PluginSpec(source='/local/path/to/plugin'),
|
||||
]
|
||||
|
||||
# Act
|
||||
result = await self.service._finalize_conversation_request(
|
||||
mock_agent,
|
||||
None,
|
||||
self.mock_user,
|
||||
workspace,
|
||||
None,
|
||||
secrets,
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
'/test/dir',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, StartConversationRequest)
|
||||
assert result.plugins is not None
|
||||
assert len(result.plugins) == 3
|
||||
assert result.plugins[0].source == 'github:owner/security-plugin'
|
||||
assert result.plugins[0].ref == 'v2.0.0'
|
||||
assert result.plugins[1].source == 'github:owner/monorepo'
|
||||
assert result.plugins[1].repo_path == 'plugins/logging'
|
||||
assert result.plugins[2].source == '/local/path/to/plugin'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_start_conversation_request_for_user_with_plugins(self):
|
||||
"""Test _build_start_conversation_request_for_user passes plugins to finalize method."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Arrange
|
||||
self.mock_user_context.get_user_info.return_value = self.mock_user
|
||||
self.mock_user_context.get_secrets.return_value = {}
|
||||
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
|
||||
self.mock_user_context.get_mcp_api_key.return_value = None
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='https://github.com/org/plugin.git',
|
||||
ref='main',
|
||||
parameters={'config_file': 'custom.yaml'},
|
||||
)
|
||||
]
|
||||
|
||||
# Mock _finalize_conversation_request to capture the call
|
||||
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
|
||||
self.service._finalize_conversation_request = mock_finalize
|
||||
|
||||
# Act
|
||||
await self.service._build_start_conversation_request_for_user(
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'/workspace',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_finalize.assert_called_once()
|
||||
call_kwargs = mock_finalize.call_args.kwargs
|
||||
assert call_kwargs['plugins'] == plugins
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_start_conversation_request_for_user_without_plugins(self):
|
||||
"""Test _build_start_conversation_request_for_user works without plugins."""
|
||||
# Arrange
|
||||
self.mock_user_context.get_user_info.return_value = self.mock_user
|
||||
self.mock_user_context.get_secrets.return_value = {}
|
||||
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
|
||||
self.mock_user_context.get_mcp_api_key.return_value = None
|
||||
|
||||
# Mock _finalize_conversation_request
|
||||
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
|
||||
self.service._finalize_conversation_request = mock_finalize
|
||||
|
||||
# Act
|
||||
await self.service._build_start_conversation_request_for_user(
|
||||
self.mock_sandbox,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
'/workspace',
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_finalize.assert_called_once()
|
||||
call_kwargs = mock_finalize.call_args.kwargs
|
||||
assert call_kwargs.get('plugins') is None
|
||||
|
||||
|
||||
class TestPluginSpecModel:
|
||||
"""Test cases for the PluginSpec model."""
|
||||
|
||||
def test_plugin_spec_with_all_fields(self):
|
||||
"""Test PluginSpec with all fields provided."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
ref='v1.0.0',
|
||||
repo_path='plugins/my-plugin',
|
||||
parameters={'key1': 'value1', 'key2': 123, 'key3': True},
|
||||
)
|
||||
|
||||
assert plugin.source == 'github:owner/repo'
|
||||
assert plugin.ref == 'v1.0.0'
|
||||
assert plugin.repo_path == 'plugins/my-plugin'
|
||||
assert plugin.parameters == {'key1': 'value1', 'key2': 123, 'key3': True}
|
||||
|
||||
def test_plugin_spec_with_only_source(self):
|
||||
"""Test PluginSpec with only source provided."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='https://github.com/owner/repo.git')
|
||||
|
||||
assert plugin.source == 'https://github.com/owner/repo.git'
|
||||
assert plugin.ref is None
|
||||
assert plugin.repo_path is None
|
||||
assert plugin.parameters is None
|
||||
|
||||
def test_plugin_spec_serialization(self):
|
||||
"""Test PluginSpec serialization to JSON."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
ref='main',
|
||||
repo_path='plugins/my-plugin',
|
||||
parameters={'debug': True},
|
||||
)
|
||||
|
||||
json_data = plugin.model_dump()
|
||||
assert json_data == {
|
||||
'source': 'github:owner/repo',
|
||||
'ref': 'main',
|
||||
'repo_path': 'plugins/my-plugin',
|
||||
'parameters': {'debug': True},
|
||||
}
|
||||
|
||||
def test_plugin_spec_deserialization(self):
|
||||
"""Test PluginSpec deserialization from dict."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
data = {
|
||||
'source': 'github:owner/repo',
|
||||
'ref': 'v2.0.0',
|
||||
'repo_path': 'plugins/weather',
|
||||
'parameters': {'timeout': 30},
|
||||
}
|
||||
|
||||
plugin = PluginSpec.model_validate(data)
|
||||
|
||||
assert plugin.source == 'github:owner/repo'
|
||||
assert plugin.ref == 'v2.0.0'
|
||||
assert plugin.repo_path == 'plugins/weather'
|
||||
assert plugin.parameters == {'timeout': 30}
|
||||
|
||||
def test_plugin_spec_display_name_github_format(self):
|
||||
"""Test display_name extracts repo name from github:owner/repo format."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='github:owner/my-plugin')
|
||||
assert plugin.display_name == 'my-plugin'
|
||||
|
||||
def test_plugin_spec_display_name_git_url(self):
|
||||
"""Test display_name extracts repo name from git URL."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='https://github.com/owner/repo.git')
|
||||
assert plugin.display_name == 'repo.git'
|
||||
|
||||
def test_plugin_spec_display_name_local_path(self):
|
||||
"""Test display_name extracts directory name from local path."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='/local/path/to/plugin')
|
||||
assert plugin.display_name == 'plugin'
|
||||
|
||||
def test_plugin_spec_display_name_no_slash(self):
|
||||
"""Test display_name returns source as-is when no slash present."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='local-plugin')
|
||||
assert plugin.display_name == 'local-plugin'
|
||||
|
||||
def test_plugin_spec_format_params_as_text(self):
|
||||
"""Test format_params_as_text formats parameters as text."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
parameters={'key1': 'value1', 'key2': 123},
|
||||
)
|
||||
|
||||
result = plugin.format_params_as_text()
|
||||
assert result == '- key1: value1\n- key2: 123'
|
||||
|
||||
def test_plugin_spec_format_params_as_text_with_indent(self):
|
||||
"""Test format_params_as_text with custom indent."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(
|
||||
source='github:owner/repo',
|
||||
parameters={'debug': True},
|
||||
)
|
||||
|
||||
result = plugin.format_params_as_text(indent=' ')
|
||||
assert result == ' - debug: True'
|
||||
|
||||
def test_plugin_spec_format_params_as_text_no_params(self):
|
||||
"""Test format_params_as_text returns None when no parameters."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugin = PluginSpec(source='github:owner/repo')
|
||||
assert plugin.format_params_as_text() is None
|
||||
|
||||
def test_plugin_spec_inherits_repo_path_validation(self):
|
||||
"""Test PluginSpec inherits validation from SDK's PluginSource."""
|
||||
import pytest
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
# Should reject absolute paths
|
||||
with pytest.raises(ValueError, match='must be relative'):
|
||||
PluginSpec(source='github:owner/repo', repo_path='/absolute/path')
|
||||
|
||||
# Should reject parent traversal
|
||||
with pytest.raises(ValueError, match="cannot contain '..'"):
|
||||
PluginSpec(source='github:owner/repo', repo_path='../parent/path')
|
||||
|
||||
|
||||
class TestAppConversationStartRequestWithPlugins:
|
||||
"""Test cases for AppConversationStartRequest with plugins field."""
|
||||
|
||||
def test_start_request_with_plugins(self):
|
||||
"""Test AppConversationStartRequest with plugins field."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(
|
||||
source='github:owner/my-plugin',
|
||||
ref='v1.0.0',
|
||||
parameters={'api_key': 'test'},
|
||||
)
|
||||
]
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
title='Test conversation',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
assert request.plugins is not None
|
||||
assert len(request.plugins) == 1
|
||||
assert request.plugins[0].source == 'github:owner/my-plugin'
|
||||
assert request.plugins[0].ref == 'v1.0.0'
|
||||
assert request.plugins[0].parameters == {'api_key': 'test'}
|
||||
|
||||
def test_start_request_without_plugins(self):
|
||||
"""Test AppConversationStartRequest without plugins field (backwards compatible)."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
)
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
title='Test conversation',
|
||||
)
|
||||
|
||||
assert request.plugins is None
|
||||
|
||||
def test_start_request_serialization_with_plugins(self):
|
||||
"""Test AppConversationStartRequest serialization includes plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [PluginSpec(source='github:owner/repo')]
|
||||
request = AppConversationStartRequest(plugins=plugins)
|
||||
|
||||
json_data = request.model_dump()
|
||||
|
||||
assert 'plugins' in json_data
|
||||
assert len(json_data['plugins']) == 1
|
||||
assert json_data['plugins'][0]['source'] == 'github:owner/repo'
|
||||
|
||||
def test_start_request_deserialization_with_plugins(self):
|
||||
"""Test AppConversationStartRequest deserialization from JSON with plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
)
|
||||
|
||||
data = {
|
||||
'title': 'Test',
|
||||
'plugins': [
|
||||
{
|
||||
'source': 'github:owner/plugin',
|
||||
'ref': 'main',
|
||||
'parameters': {'key': 'value'},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
request = AppConversationStartRequest.model_validate(data)
|
||||
|
||||
assert request.plugins is not None
|
||||
assert len(request.plugins) == 1
|
||||
assert request.plugins[0].source == 'github:owner/plugin'
|
||||
assert request.plugins[0].ref == 'main'
|
||||
assert request.plugins[0].parameters == {'key': 'value'}
|
||||
|
||||
def test_start_request_with_multiple_plugins(self):
|
||||
"""Test AppConversationStartRequest with multiple plugins."""
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
PluginSpec,
|
||||
)
|
||||
|
||||
plugins = [
|
||||
PluginSpec(source='github:owner/plugin1', ref='v1.0.0'),
|
||||
PluginSpec(source='github:owner/plugin2', repo_path='plugins/sub'),
|
||||
PluginSpec(source='/local/path'),
|
||||
]
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
title='Test conversation',
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
assert request.plugins is not None
|
||||
assert len(request.plugins) == 3
|
||||
assert request.plugins[0].source == 'github:owner/plugin1'
|
||||
assert request.plugins[1].repo_path == 'plugins/sub'
|
||||
assert request.plugins[2].source == '/local/path'
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user