Compare commits

..

1 Commits

Author SHA1 Message Date
mamoodi
c97d66131d Release 1.2.0 2026-01-15 10:08:32 -05:00
336 changed files with 20604 additions and 32704 deletions

View File

@@ -39,7 +39,8 @@ jobs:
run: |
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
json=$(jq -n -c '[
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" }
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" },
{ image: "ubuntu:24.04", tag: "ubuntu" }
]')
else
json=$(jq -n -c '[
@@ -148,9 +149,6 @@ jobs:
push: true
tags: ${{ env.DOCKER_TAGS }}
platforms: ${{ env.DOCKER_PLATFORM }}
# Caching directives to boost performance
cache-from: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}
cache-to: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }},mode=max
build-args: ${{ env.DOCKER_BUILD_ARGS }}
context: containers/runtime
provenance: false

View File

@@ -12,8 +12,7 @@ services:
- SANDBOX_API_HOSTNAME=host.docker.internal
- DOCKER_HOST_ADDR=host.docker.internal
#
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/runtime}
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.2-nikolaik}
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.2-nikolaik}
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
ports:

View File

@@ -7,8 +7,7 @@ services:
image: openhands:latest
container_name: openhands-app-${DATE:-}
environment:
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-31536c8-python}
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.2-nikolaik}
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
ports:

View File

@@ -2,7 +2,7 @@ BACKEND_HOST ?= "127.0.0.1"
BACKEND_PORT = 3000
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
FRONTEND_PORT = 3001
OPENHANDS_PATH ?= ".."
OPENHANDS_PATH ?= "../../OpenHands"
OPENHANDS := $(OPENHANDS_PATH)
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build

View File

@@ -26,14 +26,12 @@ from integrations.utils import (
from integrations.v1_utils import get_saas_user_auth
from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
from server.auth.auth_error import ExpiredError
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.integrations.service_types import AuthenticationError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
@@ -349,7 +347,7 @@ class GithubManager(Manager):
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except (AuthenticationError, ExpiredError, SessionExpiredError) as e:
except SessionExpiredError as e:
logger.warning(
f'[GitHub] Session expired for user {user_info.username}: {str(e)}'
)

View File

@@ -1,6 +1,6 @@
import asyncio
from integrations.store_repo_utils import store_repositories_in_db
from integrations.utils import store_repositories_in_db
from pydantic import SecretStr
from server.auth.token_manager import TokenManager

View File

@@ -25,9 +25,9 @@ from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager
from server.config import get_config
from storage.database import session_maker
from storage.org_store import OrgStore
from storage.proactive_conversation_store import ProactiveConversationStore
from storage.saas_secrets_store import SaasSecretsStore
from storage.saas_settings_store import SaasSettingsStore
from openhands.agent_server.models import SendMessageRequest
from openhands.app_server.app_conversation.app_conversation_models import (
@@ -78,17 +78,19 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
if not user_id:
return False
# Check global setting first - if disabled globally, return False
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
settings = await call_sync_from_async(
settings_store.get_user_settings_by_keycloak_id, user_id
)
if not settings or settings.enable_proactive_conversation_starters is None:
return False
def _get_setting():
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
if not org:
return False
return bool(org.enable_proactive_conversation_starters)
return await call_sync_from_async(_get_setting)
return settings.enable_proactive_conversation_starters
# =================================================
@@ -149,7 +151,6 @@ class GithubIssue(ResolverViewInterface):
issue_body=self.description,
previous_comments=self.previous_comments,
)
return user_instructions, conversation_instructions
async def _get_user_secrets(self):
@@ -188,7 +189,6 @@ class GithubIssue(ResolverViewInterface):
conversation_trigger=ConversationTrigger.RESOLVER,
git_provider=ProviderType.GITHUB,
)
self.conversation_id = conversation_metadata.conversation_id
return conversation_metadata
@@ -327,6 +327,7 @@ class GithubIssueComment(GithubIssue):
conversation_instructions_template = jinja_env.get_template(
'issue_conversation_instructions.j2'
)
conversation_instructions = conversation_instructions_template.render(
issue_number=self.issue_number,
issue_title=self.title,
@@ -396,6 +397,7 @@ class GithubInlinePRComment(GithubPRComment):
conversation_instructions_template = jinja_env.get_template(
'pr_update_conversation_instructions.j2'
)
conversation_instructions = conversation_instructions_template.render(
pr_number=self.issue_number,
pr_title=self.title,

View File

@@ -1,7 +1,7 @@
import asyncio
from integrations.store_repo_utils import store_repositories_in_db
from integrations.types import GitLabResourceType
from integrations.utils import store_repositories_in_db
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus

View File

@@ -1,37 +1,22 @@
"""Jira integration manager.
This module orchestrates the processing of Jira webhook events:
1. Parse webhook payload (via JiraPayloadParser)
2. Validate workspace
3. Authenticate user
4. Create view with repository selection (via JiraFactory)
5. Start conversation job
The manager delegates payload parsing to JiraPayloadParser and view creation
to JiraFactory, keeping the orchestration logic clean and traceable.
"""
import hashlib
import hmac
from typing import Dict, Optional, Tuple
from urllib.parse import urlparse
import httpx
from integrations.jira.jira_payload import (
JiraPayloadError,
JiraPayloadParser,
JiraPayloadSkipped,
JiraPayloadSuccess,
JiraWebhookPayload,
from fastapi import Request
from integrations.jira.jira_types import JiraViewInterface
from integrations.jira.jira_view import (
JiraExistingConversationView,
JiraFactory,
JiraNewConversationView,
)
from integrations.jira.jira_types import (
JiraViewInterface,
RepositoryNotFoundError,
StartingConvoException,
)
from integrations.jira.jira_view import JiraFactory, JiraNewConversationView
from integrations.manager import Manager
from integrations.models import Message
from integrations.models import JobContext, Message
from integrations.utils import (
HOST,
HOST_URL,
OPENHANDS_RESOLVER_TEMPLATES_DIR,
get_oh_labels,
filter_potential_repos_by_user_msg,
get_session_expired_message,
)
from jinja2 import Environment, FileSystemLoader
@@ -43,6 +28,9 @@ from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
@@ -53,211 +41,303 @@ from openhands.utils.http_session import httpx_verify_option
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
# Get OH labels for this environment
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
class JiraManager(Manager):
"""Manager for processing Jira webhook events.
This class orchestrates the flow from webhook receipt to conversation creation,
delegating parsing to JiraPayloadParser and view creation to JiraFactory.
"""
def __init__(self, token_manager: TokenManager):
self.token_manager = token_manager
self.integration_store = JiraIntegrationStore.get_instance()
self.jinja_env = Environment(
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira')
)
self.payload_parser = JiraPayloadParser(
oh_label=OH_LABEL,
inline_oh_label=INLINE_OH_LABEL,
)
async def receive_message(self, message: Message):
"""Process incoming Jira webhook message.
Flow:
1. Parse webhook payload
2. Validate workspace exists and is active
3. Authenticate user
4. Create view (includes fetching issue details and selecting repository)
5. Start job
Each step has clear logging for traceability.
"""
raw_payload = message.message.get('payload', {})
# Step 1: Parse webhook payload
logger.info(
'[Jira] Received webhook',
extra={'raw_payload': raw_payload},
)
parse_result = self.payload_parser.parse(raw_payload)
if isinstance(parse_result, JiraPayloadSkipped):
logger.info(
'[Jira] Webhook skipped', extra={'reason': parse_result.skip_reason}
)
return
if isinstance(parse_result, JiraPayloadError):
logger.warning(
'[Jira] Webhook parse failed', extra={'error': parse_result.error}
)
return
payload = parse_result.payload
logger.info(
'[Jira] Processing webhook',
extra={
'event_type': payload.event_type.value,
'issue_key': payload.issue_key,
'user_email': payload.user_email,
},
)
# Step 2: Validate workspace
workspace = await self._get_active_workspace(payload)
if not workspace:
return
# Step 3: Authenticate user
jira_user, saas_user_auth = await self._authenticate_user(payload, workspace)
if not jira_user or not saas_user_auth:
return
# Step 4: Create view (includes issue details fetch and repo selection)
decrypted_api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
try:
view = await JiraFactory.create_view(
payload=payload,
workspace=workspace,
user=jira_user,
user_auth=saas_user_auth,
decrypted_api_key=decrypted_api_key,
)
except RepositoryNotFoundError as e:
logger.warning(
'[Jira] Repository not found',
extra={'issue_key': payload.issue_key, 'error': str(e)},
)
await self._send_error_from_payload(payload, workspace, str(e))
return
except StartingConvoException as e:
logger.warning(
'[Jira] View creation failed',
extra={'issue_key': payload.issue_key, 'error': str(e)},
)
await self._send_error_from_payload(payload, workspace, str(e))
return
except Exception as e:
logger.error(
'[Jira] Unexpected error creating view',
extra={'issue_key': payload.issue_key, 'error': str(e)},
exc_info=True,
)
await self._send_error_from_payload(
payload,
workspace,
'Failed to initialize conversation. Please try again.',
)
return
# Step 5: Start job
await self.start_job(view)
async def _get_active_workspace(
self, payload: JiraWebhookPayload
) -> JiraWorkspace | None:
"""Validate and return the workspace for the webhook.
Returns None if:
- Workspace not found
- Workspace is inactive
- Request is from service account (to prevent recursion)
"""
workspace = await self.integration_store.get_workspace_by_name(
payload.workspace_name
)
if not workspace:
logger.warning(
'[Jira] Workspace not found',
extra={'workspace_name': payload.workspace_name},
)
# Can't send error without workspace credentials
return None
# Prevent recursive triggers from service account
if payload.user_email == workspace.svc_acc_email:
logger.debug(
'[Jira] Ignoring service account trigger',
extra={'workspace_name': payload.workspace_name},
)
return None
if workspace.status != 'active':
logger.warning(
'[Jira] Workspace inactive',
extra={'workspace_id': workspace.id, 'status': workspace.status},
)
await self._send_error_from_payload(
payload, workspace, 'Jira integration is not active for your workspace.'
)
return None
return workspace
async def _authenticate_user(
self, payload: JiraWebhookPayload, workspace: JiraWorkspace
async def authenticate_user(
self, jira_user_id: str, workspace_id: int
) -> tuple[JiraUser | None, UserAuth | None]:
"""Authenticate the Jira user and get OpenHands auth."""
"""Authenticate Jira user and get their OpenHands user auth."""
# Find active Jira user by Keycloak user ID and workspace ID
jira_user = await self.integration_store.get_active_user(
payload.account_id, workspace.id
jira_user_id, workspace_id
)
if not jira_user:
logger.warning(
'[Jira] User not found or inactive',
extra={
'account_id': payload.account_id,
'user_email': payload.user_email,
'workspace_id': workspace.id,
},
)
await self._send_error_from_payload(
payload,
workspace,
f'User {payload.user_email} is not authenticated or active in the Jira integration.',
f'[Jira] No active Jira user found for {jira_user_id} in workspace {workspace_id}'
)
return None, None
saas_user_auth = await get_user_auth_from_keycloak_id(
jira_user.keycloak_user_id
)
if not saas_user_auth:
logger.warning(
'[Jira] Failed to get OpenHands auth',
extra={
'keycloak_user_id': jira_user.keycloak_user_id,
'user_email': payload.user_email,
},
)
await self._send_error_from_payload(
payload,
workspace,
f'User {payload.user_email} is not authenticated with OpenHands.',
)
return None, None
return jira_user, saas_user_auth
async def start_job(self, view: JiraViewInterface):
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
"""Get repositories that the user has access to."""
provider_tokens = await user_auth.get_provider_tokens()
if provider_tokens is None:
return []
access_token = await user_auth.get_access_token()
user_id = await user_auth.get_user_id()
client = ProviderHandler(
provider_tokens=provider_tokens,
external_auth_token=access_token,
external_auth_id=user_id,
)
repos: list[Repository] = await client.get_repositories(
'pushed', server_config.app_mode, None, None, None, None
)
return repos
async def validate_request(
self, request: Request
) -> Tuple[bool, Optional[str], Optional[Dict]]:
"""Verify Jira webhook signature."""
signature_header = request.headers.get('x-hub-signature')
signature = signature_header.split('=')[1] if signature_header else None
body = await request.body()
payload = await request.json()
workspace_name = ''
if payload.get('webhookEvent') == 'comment_created':
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
elif payload.get('webhookEvent') == 'jira:issue_updated':
selfUrl = payload.get('user', {}).get('self')
else:
workspace_name = ''
parsedUrl = urlparse(selfUrl)
if parsedUrl.hostname:
workspace_name = parsedUrl.hostname
if not workspace_name:
logger.warning('[Jira] No workspace name found in webhook payload')
return False, None, None
if not signature:
logger.warning('[Jira] No signature found in webhook headers')
return False, None, None
workspace = await self.integration_store.get_workspace_by_name(workspace_name)
if not workspace:
logger.warning('[Jira] Could not identify workspace for webhook')
return False, None, None
if workspace.status != 'active':
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
return False, None, None
webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()
if hmac.compare_digest(signature, digest):
logger.info('[Jira] Webhook signature verified successfully')
return True, signature, payload
return False, None, None
def parse_webhook(self, payload: Dict) -> JobContext | None:
event_type = payload.get('webhookEvent')
if event_type == 'comment_created':
comment_data = payload.get('comment', {})
comment = comment_data.get('body', '')
if '@openhands' not in comment:
return None
issue_data = payload.get('issue', {})
issue_id = issue_data.get('id')
issue_key = issue_data.get('key')
base_api_url = issue_data.get('self', '').split('/rest/')[0]
user_data = comment_data.get('author', {})
user_email = user_data.get('emailAddress')
display_name = user_data.get('displayName')
account_id = user_data.get('accountId')
elif event_type == 'jira:issue_updated':
changelog = payload.get('changelog', {})
items = changelog.get('items', [])
labels = [
item.get('toString', '')
for item in items
if item.get('field') == 'labels' and 'toString' in item
]
if 'openhands' not in labels:
return None
issue_data = payload.get('issue', {})
issue_id = issue_data.get('id')
issue_key = issue_data.get('key')
base_api_url = issue_data.get('self', '').split('/rest/')[0]
user_data = payload.get('user', {})
user_email = user_data.get('emailAddress')
display_name = user_data.get('displayName')
account_id = user_data.get('accountId')
comment = ''
else:
return None
workspace_name = ''
parsedUrl = urlparse(base_api_url)
if parsedUrl.hostname:
workspace_name = parsedUrl.hostname
if not all(
[
issue_id,
issue_key,
user_email,
display_name,
account_id,
workspace_name,
base_api_url,
]
):
return None
return JobContext(
issue_id=issue_id,
issue_key=issue_key,
user_msg=comment,
user_email=user_email,
display_name=display_name,
platform_user_id=account_id,
workspace_name=workspace_name,
base_api_url=base_api_url,
)
async def receive_message(self, message: Message):
"""Process incoming Jira webhook message."""
payload = message.message.get('payload', {})
job_context = self.parse_webhook(payload)
if not job_context:
logger.info('[Jira] Webhook does not match trigger conditions')
return
# Get workspace by user email domain
workspace = await self.integration_store.get_workspace_by_name(
job_context.workspace_name
)
if not workspace:
logger.warning(
f'[Jira] No workspace found for email domain: {job_context.user_email}'
)
await self._send_error_comment(
job_context,
'Your workspace is not configured with Jira integration.',
None,
)
return
# Prevent any recursive triggers from the service account
if job_context.user_email == workspace.svc_acc_email:
return
if workspace.status != 'active':
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
await self._send_error_comment(
job_context,
'Jira integration is not active for your workspace.',
workspace,
)
return
# Authenticate user
jira_user, saas_user_auth = await self.authenticate_user(
job_context.platform_user_id, workspace.id
)
if not jira_user or not saas_user_auth:
logger.warning(
f'[Jira] User authentication failed for {job_context.user_email}'
)
await self._send_error_comment(
job_context,
f'User {job_context.user_email} is not authenticated or active in the Jira integration.',
workspace,
)
return
# Get issue details
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
issue_title, issue_description = await self.get_issue_details(
job_context, workspace.jira_cloud_id, workspace.svc_acc_email, api_key
)
job_context.issue_title = issue_title
job_context.issue_description = issue_description
except Exception as e:
logger.error(f'[Jira] Failed to get issue context: {str(e)}')
await self._send_error_comment(
job_context,
'Failed to retrieve issue details. Please check the issue key and try again.',
workspace,
)
return
try:
# Create Jira view
jira_view = await JiraFactory.create_jira_view_from_payload(
job_context,
saas_user_auth,
jira_user,
workspace,
)
except Exception as e:
logger.error(f'[Jira] Failed to create jira view: {str(e)}', exc_info=True)
await self._send_error_comment(
job_context,
'Failed to initialize conversation. Please try again.',
workspace,
)
return
if not await self.is_job_requested(message, jira_view):
return
await self.start_job(jira_view)
async def is_job_requested(
self, message: Message, jira_view: JiraViewInterface
) -> bool:
"""
Check if a job is requested and handle repository selection.
"""
if isinstance(jira_view, JiraExistingConversationView):
return True
try:
# Get user repositories
user_repos: list[Repository] = await self._get_repositories(
jira_view.saas_user_auth
)
target_str = f'{jira_view.job_context.issue_description}\n{jira_view.job_context.user_msg}'
# Try to infer repository from issue description
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
if match:
# Found exact repository match
jira_view.selected_repo = repos[0].full_name
logger.info(f'[Jira] Inferred repository: {repos[0].full_name}')
return True
else:
# No clear match - send repository selection comment
await self._send_repo_selection_comment(jira_view)
return False
except Exception as e:
logger.error(f'[Jira] Error in is_job_requested: {str(e)}')
return False
async def start_job(self, jira_view: JiraViewInterface):
"""Start a Jira job/conversation."""
# Import here to prevent circular import
from server.conversation_callback_processor.jira_callback_processor import (
@@ -265,79 +345,101 @@ class JiraManager(Manager):
)
try:
user_info: JiraUser = jira_view.jira_user
logger.info(
'[Jira] Starting job',
extra={
'issue_key': view.payload.issue_key,
'user_id': view.jira_user.keycloak_user_id,
'selected_repo': view.selected_repo,
},
f'[Jira] Starting job for user {user_info.keycloak_user_id} '
f'issue {jira_view.job_context.issue_key}',
)
# Create conversation
conversation_id = await view.create_or_update_conversation(self.jinja_env)
conversation_id = await jira_view.create_or_update_conversation(
self.jinja_env
)
logger.info(
'[Jira] Conversation created',
extra={
'conversation_id': conversation_id,
'issue_key': view.payload.issue_key,
},
f'[Jira] Created/Updated conversation {conversation_id} for issue {jira_view.job_context.issue_key}'
)
# Register callback processor for updates
if isinstance(view, JiraNewConversationView):
if isinstance(jira_view, JiraNewConversationView):
processor = JiraCallbackProcessor(
issue_key=view.payload.issue_key,
workspace_name=view.jira_workspace.name,
)
register_callback_processor(conversation_id, processor)
logger.info(
'[Jira] Callback processor registered',
extra={'conversation_id': conversation_id},
issue_key=jira_view.job_context.issue_key,
workspace_name=jira_view.jira_workspace.name,
)
# Send success response
msg_info = view.get_response_msg()
# Register the callback processor
register_callback_processor(conversation_id, processor)
logger.info(
f'[Jira] Created callback processor for conversation {conversation_id}'
)
# Send initial response
msg_info = jira_view.get_response_msg()
except MissingSettingsError as e:
logger.warning(
'[Jira] Missing settings error',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
logger.warning(f'[Jira] Missing settings error: {str(e)}')
msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
except LLMAuthenticationError as e:
logger.warning(
'[Jira] LLM authentication error',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
logger.warning(f'[Jira] LLM authentication error: {str(e)}')
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except SessionExpiredError as e:
logger.warning(
'[Jira] Session expired',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
logger.warning(f'[Jira] Session expired: {str(e)}')
msg_info = get_session_expired_message()
except StartingConvoException as e:
logger.warning(
'[Jira] Conversation start failed',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
msg_info = str(e)
except Exception as e:
logger.error(
'[Jira] Unexpected error starting job',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
exc_info=True,
f'[Jira] Unexpected error starting job: {str(e)}', exc_info=True
)
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
# Send response comment
await self._send_comment(view, msg_info)
try:
api_key = self.token_manager.decrypt_text(
jira_view.jira_workspace.svc_acc_api_key
)
await self.send_message(
self.create_outgoing_message(msg=msg_info),
issue_key=jira_view.job_context.issue_key,
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
svc_acc_api_key=api_key,
)
except Exception as e:
logger.error(f'[Jira] Failed to send response message: {str(e)}')
async def get_issue_details(
self,
job_context: JobContext,
jira_cloud_id: str,
svc_acc_email: str,
svc_acc_api_key: str,
) -> Tuple[str, str]:
url = f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{job_context.issue_key}'
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.get(url, auth=(svc_acc_email, svc_acc_api_key))
response.raise_for_status()
issue_payload = response.json()
if not issue_payload:
raise ValueError(f'Issue with key {job_context.issue_key} not found.')
title = issue_payload.get('fields', {}).get('summary', '')
description = issue_payload.get('fields', {}).get('description', '')
if not title:
raise ValueError(
f'Issue with key {job_context.issue_key} does not have a title.'
)
if not description:
raise ValueError(
f'Issue with key {job_context.issue_key} does not have a description.'
)
return title, description
async def send_message(
self,
@@ -347,7 +449,6 @@ class JiraManager(Manager):
svc_acc_email: str,
svc_acc_api_key: str,
):
"""Send a comment to a Jira issue."""
url = (
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
)
@@ -359,53 +460,54 @@ class JiraManager(Manager):
response.raise_for_status()
return response.json()
async def _send_comment(self, view: JiraViewInterface, msg: str):
"""Send a comment using credentials from the view."""
try:
api_key = self.token_manager.decrypt_text(
view.jira_workspace.svc_acc_api_key
)
await self.send_message(
self.create_outgoing_message(msg=msg),
issue_key=view.payload.issue_key,
jira_cloud_id=view.jira_workspace.jira_cloud_id,
svc_acc_email=view.jira_workspace.svc_acc_email,
svc_acc_api_key=api_key,
)
except Exception as e:
logger.error(
'[Jira] Failed to send comment',
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
)
async def _send_error_from_payload(
async def _send_error_comment(
self,
payload: JiraWebhookPayload,
workspace: JiraWorkspace,
job_context: JobContext,
error_msg: str,
workspace: JiraWorkspace | None,
):
"""Send error comment before view is created (using payload directly)."""
"""Send error comment to Jira issue."""
if not workspace:
logger.error('[Jira] Cannot send error comment - no workspace available')
return
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
await self.send_message(
self.create_outgoing_message(msg=error_msg),
issue_key=payload.issue_key,
issue_key=job_context.issue_key,
jira_cloud_id=workspace.jira_cloud_id,
svc_acc_email=workspace.svc_acc_email,
svc_acc_api_key=api_key,
)
except Exception as e:
logger.error(
'[Jira] Failed to send error comment',
extra={'issue_key': payload.issue_key, 'error': str(e)},
logger.error(f'[Jira] Failed to send error comment: {str(e)}')
async def _send_repo_selection_comment(self, jira_view: JiraViewInterface):
"""Send a comment with repository options for the user to choose."""
try:
comment_msg = (
'I need to know which repository to work with. '
'Please add it to your issue description or send a followup comment.'
)
def get_workspace_name_from_payload(self, payload: dict) -> str | None:
"""Extract workspace name from Jira webhook payload.
api_key = self.token_manager.decrypt_text(
jira_view.jira_workspace.svc_acc_api_key
)
This method is used by the route for signature verification.
"""
parse_result = self.payload_parser.parse(payload)
if isinstance(parse_result, JiraPayloadSuccess):
return parse_result.payload.workspace_name
return None
await self.send_message(
self.create_outgoing_message(msg=comment_msg),
issue_key=jira_view.job_context.issue_key,
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
svc_acc_api_key=api_key,
)
logger.info(
f'[Jira] Sent repository selection comment for issue {jira_view.job_context.issue_key}'
)
except Exception as e:
logger.error(
f'[Jira] Failed to send repository selection comment: {str(e)}'
)

View File

@@ -1,267 +0,0 @@
"""Centralized payload parsing for Jira webhooks.
This module provides a single source of truth for parsing and validating
Jira webhook payloads, replacing scattered parsing logic throughout the codebase.
"""
from dataclasses import dataclass
from enum import Enum
from urllib.parse import urlparse
from openhands.core.logger import openhands_logger as logger
class JiraEventType(Enum):
"""Types of Jira events we handle."""
LABELED_TICKET = 'labeled_ticket'
COMMENT_MENTION = 'comment_mention'
@dataclass(frozen=True)
class JiraWebhookPayload:
"""Normalized, validated representation of a Jira webhook payload.
This immutable dataclass replaces JobContext and provides a single
source of truth for all webhook data. All parsing happens in
JiraPayloadParser, ensuring consistent validation.
"""
event_type: JiraEventType
raw_event: str # Original webhookEvent value
# Issue data
issue_id: str
issue_key: str
# User data
user_email: str
display_name: str
account_id: str
# Workspace data (derived from issue self URL)
workspace_name: str
base_api_url: str
# Event-specific data
comment_body: str = '' # For comment events
@property
def user_msg(self) -> str:
"""Alias for comment_body for backward compatibility."""
return self.comment_body
class JiraPayloadParseError(Exception):
"""Raised when payload parsing fails."""
def __init__(self, reason: str, event_type: str | None = None):
self.reason = reason
self.event_type = event_type
super().__init__(reason)
@dataclass(frozen=True)
class JiraPayloadSuccess:
"""Result when parsing succeeds."""
payload: JiraWebhookPayload
@dataclass(frozen=True)
class JiraPayloadSkipped:
"""Result when event is intentionally skipped."""
skip_reason: str
@dataclass(frozen=True)
class JiraPayloadError:
"""Result when parsing fails due to invalid data."""
error: str
JiraPayloadParseResult = JiraPayloadSuccess | JiraPayloadSkipped | JiraPayloadError
class JiraPayloadParser:
"""Centralized parser for Jira webhook payloads.
This class provides a single entry point for parsing webhooks,
determining event types, and extracting all necessary fields.
Replaces scattered parsing in JiraFactory and JiraManager.
"""
def __init__(self, oh_label: str, inline_oh_label: str):
"""Initialize parser with OpenHands label configuration.
Args:
oh_label: Label that triggers OpenHands (e.g., 'openhands')
inline_oh_label: Mention that triggers OpenHands (e.g., '@openhands')
"""
self.oh_label = oh_label
self.inline_oh_label = inline_oh_label
def parse(self, raw_payload: dict) -> JiraPayloadParseResult:
"""Parse a raw webhook payload into a normalized JiraWebhookPayload.
Args:
raw_payload: The raw webhook payload dict from Jira
Returns:
One of:
- JiraPayloadSuccess: Valid, actionable event with payload
- JiraPayloadSkipped: Event we intentionally don't process
- JiraPayloadError: Malformed payload we expected to process
"""
webhook_event = raw_payload.get('webhookEvent', '')
logger.debug(
'[Jira] Parsing webhook payload', extra={'webhook_event': webhook_event}
)
if webhook_event == 'jira:issue_updated':
return self._parse_label_event(raw_payload, webhook_event)
elif webhook_event == 'comment_created':
return self._parse_comment_event(raw_payload, webhook_event)
else:
return JiraPayloadSkipped(f'Unhandled webhook event type: {webhook_event}')
def _parse_label_event(
self, payload: dict, webhook_event: str
) -> JiraPayloadParseResult:
"""Parse an issue_updated event for label changes."""
changelog = payload.get('changelog', {})
items = changelog.get('items', [])
# Extract labels that were added
labels = [
item.get('toString', '')
for item in items
if item.get('field') == 'labels' and 'toString' in item
]
if self.oh_label not in labels:
return JiraPayloadSkipped(
f"Label event does not contain '{self.oh_label}' label"
)
# For label events, user data comes from 'user' field
user_data = payload.get('user', {})
return self._extract_and_validate(
payload=payload,
user_data=user_data,
event_type=JiraEventType.LABELED_TICKET,
webhook_event=webhook_event,
comment_body='',
)
def _parse_comment_event(
self, payload: dict, webhook_event: str
) -> JiraPayloadParseResult:
"""Parse a comment_created event."""
comment_data = payload.get('comment', {})
comment_body = comment_data.get('body', '')
if not self._has_mention(comment_body):
return JiraPayloadSkipped(
f"Comment does not mention '{self.inline_oh_label}'"
)
# For comment events, user data comes from 'comment.author'
user_data = comment_data.get('author', {})
return self._extract_and_validate(
payload=payload,
user_data=user_data,
event_type=JiraEventType.COMMENT_MENTION,
webhook_event=webhook_event,
comment_body=comment_body,
)
def _has_mention(self, text: str) -> bool:
"""Check if text contains an exact mention of OpenHands."""
from integrations.utils import has_exact_mention
return has_exact_mention(text, self.inline_oh_label)
def _extract_and_validate(
self,
payload: dict,
user_data: dict,
event_type: JiraEventType,
webhook_event: str,
comment_body: str,
) -> JiraPayloadParseResult:
"""Extract common fields and validate required data is present."""
issue_data = payload.get('issue', {})
# Extract all fields with empty string defaults (makes them str type)
issue_id = issue_data.get('id', '')
issue_key = issue_data.get('key', '')
user_email = user_data.get('emailAddress', '')
display_name = user_data.get('displayName', '')
account_id = user_data.get('accountId', '')
base_api_url, workspace_name = self._extract_workspace_from_url(
issue_data.get('self', '')
)
# Validate required fields
missing: list[str] = []
if not issue_id:
missing.append('issue.id')
if not issue_key:
missing.append('issue.key')
if not user_email:
missing.append('user.emailAddress')
if not display_name:
missing.append('user.displayName')
if not account_id:
missing.append('user.accountId')
if not workspace_name:
missing.append('workspace_name (derived from issue.self)')
if not base_api_url:
missing.append('base_api_url (derived from issue.self)')
if missing:
return JiraPayloadError(f"Missing required fields: {', '.join(missing)}")
return JiraPayloadSuccess(
JiraWebhookPayload(
event_type=event_type,
raw_event=webhook_event,
issue_id=issue_id,
issue_key=issue_key,
user_email=user_email,
display_name=display_name,
account_id=account_id,
workspace_name=workspace_name,
base_api_url=base_api_url,
comment_body=comment_body,
)
)
def _extract_workspace_from_url(self, self_url: str) -> tuple[str, str]:
"""Extract base API URL and workspace name from issue self URL.
Args:
self_url: The 'self' URL from the issue data
Returns:
Tuple of (base_api_url, workspace_name)
"""
if not self_url:
return '', ''
# Extract base URL (everything before /rest/)
if '/rest/' in self_url:
base_api_url = self_url.split('/rest/')[0]
else:
parsed = urlparse(self_url)
base_api_url = f'{parsed.scheme}://{parsed.netloc}'
# Extract workspace name (hostname)
parsed = urlparse(base_api_url)
workspace_name = parsed.hostname or ''
return base_api_url, workspace_name

View File

@@ -1,42 +1,26 @@
"""Type definitions and interfaces for Jira integration."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from integrations.models import JobContext
from jinja2 import Environment
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from openhands.server.user_auth.user_auth import UserAuth
if TYPE_CHECKING:
from integrations.jira.jira_payload import JiraWebhookPayload
class JiraViewInterface(ABC):
"""Interface for Jira views that handle different types of Jira interactions.
"""Interface for Jira views that handle different types of Jira interactions."""
Views hold the webhook payload directly rather than duplicating fields,
and fetch issue details lazily when needed.
"""
# Core data - view holds these references
payload: 'JiraWebhookPayload'
job_context: JobContext
saas_user_auth: UserAuth
jira_user: JiraUser
jira_workspace: JiraWorkspace
# Mutable state set during processing
selected_repo: str | None
conversation_id: str
@abstractmethod
async def get_issue_details(self) -> tuple[str, str]:
"""Fetch and cache issue title and description from Jira API.
Returns:
Tuple of (issue_title, issue_description)
"""
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Get initial instructions for the conversation."""
pass
@abstractmethod
@@ -51,21 +35,6 @@ class JiraViewInterface(ABC):
class StartingConvoException(Exception):
"""Exception raised when starting a conversation fails.
This provides user-friendly error messages that can be sent back to Jira.
"""
pass
class RepositoryNotFoundError(Exception):
"""Raised when a repository cannot be determined from the issue.
This is a separate error domain from StartingConvoException - it represents
a precondition failure (no repo configured/found) rather than a conversation
creation failure. The manager catches this and converts it to a user-friendly
message.
"""
"""Exception raised when starting a conversation fails."""
pass

View File

@@ -1,21 +1,8 @@
"""Jira view implementations and factory.
from dataclasses import dataclass
Views are responsible for:
- Holding the webhook payload and auth context
- Lazy-loading issue details from Jira API when needed
- Creating conversations with the selected repository
"""
from dataclasses import dataclass, field
import httpx
from integrations.jira.jira_payload import JiraWebhookPayload
from integrations.jira.jira_types import (
JiraViewInterface,
RepositoryNotFoundError,
StartingConvoException,
)
from integrations.utils import CONVERSATION_URL, infer_repo_from_message
from integrations.jira.jira_types import JiraViewInterface, StartingConvoException
from integrations.models import JobContext
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
from jinja2 import Environment
from storage.jira_conversation import JiraConversation
from storage.jira_integration_store import JiraIntegrationStore
@@ -23,147 +10,55 @@ from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.server.services.conversation_service import create_new_conversation
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_to_dict
from openhands.server.services.conversation_service import (
create_new_conversation,
setup_init_conversation_settings,
)
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
from openhands.utils.http_session import httpx_verify_option
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
integration_store = JiraIntegrationStore.get_instance()
@dataclass
class JiraNewConversationView(JiraViewInterface):
"""View for creating a new Jira conversation.
This view holds the webhook payload directly and lazily fetches
issue details when needed for rendering templates.
"""
payload: JiraWebhookPayload
job_context: JobContext
saas_user_auth: UserAuth
jira_user: JiraUser
jira_workspace: JiraWorkspace
selected_repo: str | None = None
conversation_id: str = ''
selected_repo: str | None
conversation_id: str
# Lazy-loaded issue details (cached after first fetch)
_issue_title: str | None = field(default=None, repr=False)
_issue_description: str | None = field(default=None, repr=False)
# Decrypted API key (set by factory)
_decrypted_api_key: str = field(default='', repr=False)
async def get_issue_details(self) -> tuple[str, str]:
"""Fetch issue details from Jira API (cached after first call).
Returns:
Tuple of (issue_title, issue_description)
Raises:
StartingConvoException: If issue details cannot be fetched
"""
if self._issue_title is not None and self._issue_description is not None:
return self._issue_title, self._issue_description
try:
url = f'{JIRA_CLOUD_API_URL}/{self.jira_workspace.jira_cloud_id}/rest/api/2/issue/{self.payload.issue_key}'
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.get(
url,
auth=(
self.jira_workspace.svc_acc_email,
self._decrypted_api_key,
),
)
response.raise_for_status()
issue_payload = response.json()
if not issue_payload:
raise StartingConvoException(
f'Issue {self.payload.issue_key} not found.'
)
self._issue_title = issue_payload.get('fields', {}).get('summary', '')
self._issue_description = (
issue_payload.get('fields', {}).get('description', '') or ''
)
if not self._issue_title:
raise StartingConvoException(
f'Issue {self.payload.issue_key} does not have a title.'
)
logger.info(
'[Jira] Fetched issue details',
extra={
'issue_key': self.payload.issue_key,
'has_description': bool(self._issue_description),
},
)
return self._issue_title, self._issue_description
except httpx.HTTPStatusError as e:
logger.error(
'[Jira] Failed to fetch issue details',
extra={
'issue_key': self.payload.issue_key,
'status': e.response.status_code,
},
)
raise StartingConvoException(
f'Failed to fetch issue details: HTTP {e.response.status_code}'
)
except Exception as e:
if isinstance(e, StartingConvoException):
raise
logger.error(
'[Jira] Failed to fetch issue details',
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
)
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Get instructions for the conversation.
This fetches issue details if not already cached.
Returns:
Tuple of (system_instructions, user_message)
"""
issue_title, issue_description = await self.get_issue_details()
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
instructions_template = jinja_env.get_template('jira_instructions.j2')
instructions = instructions_template.render()
user_msg_template = jinja_env.get_template('jira_new_conversation.j2')
user_msg = user_msg_template.render(
issue_key=self.payload.issue_key,
issue_title=issue_title,
issue_description=issue_description,
user_message=self.payload.user_msg,
issue_key=self.job_context.issue_key,
issue_title=self.job_context.issue_title,
issue_description=self.job_context.issue_description,
user_message=self.job_context.user_msg or '',
)
return instructions, user_msg
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
"""Create a new Jira conversation.
"""Create a new Jira conversation"""
Returns:
The conversation ID
Raises:
StartingConvoException: If conversation creation fails
"""
if not self.selected_repo:
raise StartingConvoException('No repository selected for this conversation')
provider_tokens = await self.saas_user_auth.get_provider_tokens()
user_secrets = await self.saas_user_auth.get_secrets()
instructions, user_msg = await self._get_instructions(jinja_env)
instructions, user_msg = self._get_instructions(jinja_env)
try:
agent_loop_info = await create_new_conversation(
@@ -181,259 +76,149 @@ class JiraNewConversationView(JiraViewInterface):
self.conversation_id = agent_loop_info.conversation_id
logger.info(
'[Jira] Created conversation',
extra={
'conversation_id': self.conversation_id,
'issue_key': self.payload.issue_key,
'selected_repo': self.selected_repo,
},
)
logger.info(f'[Jira] Created conversation {self.conversation_id}')
# Store Jira conversation mapping
jira_conversation = JiraConversation(
conversation_id=self.conversation_id,
issue_id=self.payload.issue_id,
issue_key=self.payload.issue_key,
issue_id=self.job_context.issue_id,
issue_key=self.job_context.issue_key,
jira_user_id=self.jira_user.id,
)
await integration_store.create_conversation(jira_conversation)
return self.conversation_id
except Exception as e:
if isinstance(e, StartingConvoException):
raise
logger.error(
'[Jira] Failed to create conversation',
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
exc_info=True,
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
)
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
def get_response_msg(self) -> str:
"""Get the response message to send back to Jira."""
"""Get the response message to send back to Jira"""
conversation_link = CONVERSATION_URL.format(self.conversation_id)
return f"I'm on it! {self.payload.display_name} can [track my progress here|{conversation_link}]."
return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
@dataclass
class JiraExistingConversationView(JiraViewInterface):
job_context: JobContext
saas_user_auth: UserAuth
jira_user: JiraUser
jira_workspace: JiraWorkspace
selected_repo: str | None
conversation_id: str
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_msg_template = jinja_env.get_template('jira_existing_conversation.j2')
user_msg = user_msg_template.render(
issue_key=self.job_context.issue_key,
user_message=self.job_context.user_msg or '',
issue_title=self.job_context.issue_title,
issue_description=self.job_context.issue_description,
)
return '', user_msg
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
"""Update an existing Jira conversation"""
user_id = self.jira_user.keycloak_user_id
try:
conversation_store = await ConversationStoreImpl.get_instance(
config, user_id
)
try:
await conversation_store.get_metadata(self.conversation_id)
except FileNotFoundError:
raise StartingConvoException('Conversation no longer exists.')
provider_tokens = await self.saas_user_auth.get_provider_tokens()
# Should we raise here if there are no providers?
providers_set = list(provider_tokens.keys()) if provider_tokens else []
conversation_init_data = await setup_init_conversation_settings(
user_id, self.conversation_id, providers_set
)
# Either join ongoing conversation, or restart the conversation
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
self.conversation_id, conversation_init_data, user_id
)
final_agent_observation = get_final_agent_observation(
agent_loop_info.event_store
)
agent_state = (
None
if len(final_agent_observation) == 0
else final_agent_observation[0].agent_state
)
if not agent_state or agent_state == AgentState.LOADING:
raise StartingConvoException('Conversation is still starting')
_, user_msg = self._get_instructions(jinja_env)
user_message_event = MessageAction(content=user_msg)
await conversation_manager.send_event_to_conversation(
self.conversation_id, event_to_dict(user_message_event)
)
return self.conversation_id
except Exception as e:
logger.error(
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
)
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
def get_response_msg(self) -> str:
"""Get the response message to send back to Jira"""
conversation_link = CONVERSATION_URL.format(self.conversation_id)
return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."
class JiraFactory:
"""Factory for creating Jira views.
The factory is responsible for:
- Creating the appropriate view type
- Inferring and selecting the repository
- Validating all required data is available
Repository selection happens here so that view creation either
succeeds with a valid repo or fails with a clear error.
"""
"""Factory for creating Jira views based on message content"""
@staticmethod
async def _create_provider_handler(user_auth: UserAuth) -> ProviderHandler | None:
"""Create a ProviderHandler for the user."""
provider_tokens = await user_auth.get_provider_tokens()
if provider_tokens is None:
return None
access_token = await user_auth.get_access_token()
user_id = await user_auth.get_user_id()
return ProviderHandler(
provider_tokens=provider_tokens,
external_auth_token=access_token,
external_auth_id=user_id,
)
@staticmethod
def _extract_potential_repos(
issue_key: str,
issue_title: str,
issue_description: str,
user_msg: str,
) -> list[str]:
"""Extract potential repository names from issue content.
Raises:
RepositoryNotFoundError: If no potential repos found in text.
"""
search_text = f'{issue_title}\n{issue_description}\n{user_msg}'
potential_repos = infer_repo_from_message(search_text)
if not potential_repos:
raise RepositoryNotFoundError(
'Could not determine which repository to use. '
'Please mention the repository (e.g., owner/repo) in the issue description or comment.'
)
logger.info(
'[Jira] Found potential repositories in issue content',
extra={'issue_key': issue_key, 'potential_repos': potential_repos},
)
return potential_repos
@staticmethod
async def _verify_repos(
issue_key: str,
potential_repos: list[str],
provider_handler: ProviderHandler,
) -> list[str]:
"""Verify which repos the user has access to."""
verified_repos: list[str] = []
for repo_name in potential_repos:
try:
repository = await provider_handler.verify_repo_provider(repo_name)
verified_repos.append(repository.full_name)
logger.debug(
'[Jira] Repository verification succeeded',
extra={'issue_key': issue_key, 'repository': repository.full_name},
)
except Exception as e:
logger.debug(
'[Jira] Repository verification failed',
extra={
'issue_key': issue_key,
'repo_name': repo_name,
'error': str(e),
},
)
return verified_repos
@staticmethod
def _select_single_repo(
issue_key: str,
potential_repos: list[str],
verified_repos: list[str],
) -> str:
"""Select exactly one repo from verified repos.
Raises:
RepositoryNotFoundError: If zero or multiple repos verified.
"""
if len(verified_repos) == 0:
raise RepositoryNotFoundError(
f'Could not access any of the mentioned repositories: {", ".join(potential_repos)}. '
'Please ensure you have access to the repository and it exists.'
)
if len(verified_repos) > 1:
raise RepositoryNotFoundError(
f'Multiple repositories found: {", ".join(verified_repos)}. '
'Please specify exactly one repository in the issue description or comment.'
)
logger.info(
'[Jira] Verified repository access',
extra={'issue_key': issue_key, 'repository': verified_repos[0]},
)
return verified_repos[0]
@staticmethod
async def _infer_repository(
payload: JiraWebhookPayload,
user_auth: UserAuth,
issue_title: str,
issue_description: str,
) -> str:
"""Infer and verify the repository from issue content.
Raises:
RepositoryNotFoundError: If no valid repository can be determined.
"""
provider_handler = await JiraFactory._create_provider_handler(user_auth)
if not provider_handler:
raise RepositoryNotFoundError(
'No Git provider connected. Please connect a Git provider in OpenHands settings.'
)
potential_repos = JiraFactory._extract_potential_repos(
payload.issue_key, issue_title, issue_description, payload.user_msg
)
verified_repos = await JiraFactory._verify_repos(
payload.issue_key, potential_repos, provider_handler
)
return JiraFactory._select_single_repo(
payload.issue_key, potential_repos, verified_repos
)
@staticmethod
async def create_view(
payload: JiraWebhookPayload,
workspace: JiraWorkspace,
user: JiraUser,
user_auth: UserAuth,
decrypted_api_key: str,
async def create_jira_view_from_payload(
job_context: JobContext,
saas_user_auth: UserAuth,
jira_user: JiraUser,
jira_workspace: JiraWorkspace,
) -> JiraViewInterface:
"""Create a Jira view with repository already selected.
"""Create appropriate Jira view based on the message and user state"""
This factory method:
1. Creates the view with payload and auth context
2. Fetches issue details (needed for repo inference)
3. Infers and selects the repository
if not jira_user or not saas_user_auth or not jira_workspace:
raise StartingConvoException('User not authenticated with Jira integration')
If any step fails, an appropriate exception is raised with
a user-friendly message.
Args:
payload: Parsed webhook payload
workspace: The Jira workspace
user: The Jira user
user_auth: OpenHands user authentication
decrypted_api_key: Decrypted service account API key
Returns:
A JiraViewInterface with selected_repo populated
Raises:
StartingConvoException: If view creation fails
RepositoryNotFoundError: If repository cannot be determined
"""
logger.info(
'[Jira] Creating view',
extra={
'issue_key': payload.issue_key,
'event_type': payload.event_type.value,
},
conversation = await integration_store.get_user_conversations_by_issue_id(
job_context.issue_id, jira_user.id
)
# Create the view
view = JiraNewConversationView(
payload=payload,
saas_user_auth=user_auth,
jira_user=user,
jira_workspace=workspace,
_decrypted_api_key=decrypted_api_key,
if conversation:
logger.info(
f'[Jira] Found existing conversation for issue {job_context.issue_id}'
)
return JiraExistingConversationView(
job_context=job_context,
saas_user_auth=saas_user_auth,
jira_user=jira_user,
jira_workspace=jira_workspace,
selected_repo=None,
conversation_id=conversation.conversation_id,
)
return JiraNewConversationView(
job_context=job_context,
saas_user_auth=saas_user_auth,
jira_user=jira_user,
jira_workspace=jira_workspace,
selected_repo=None, # Will be set later after repo inference
conversation_id='', # Will be set when conversation is created
)
# Fetch issue details (needed for repo inference)
try:
issue_title, issue_description = await view.get_issue_details()
except StartingConvoException:
raise # Re-raise with original message
except Exception as e:
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
# Infer and select repository
selected_repo = await JiraFactory._infer_repository(
payload=payload,
user_auth=user_auth,
issue_title=issue_title,
issue_description=issue_description,
)
view.selected_repo = selected_repo
logger.info(
'[Jira] View created successfully',
extra={
'issue_key': payload.issue_key,
'selected_repo': selected_repo,
},
)
return view

View File

@@ -16,6 +16,11 @@ class Manager(ABC):
"Send message to integration from Openhands server"
raise NotImplementedError
@abstractmethod
async def is_job_requested(self, message: Message) -> bool:
"Confirm that a job is being requested"
raise NotImplementedError
@abstractmethod
def start_job(self):
"Kick off a job with openhands agent"

View File

@@ -29,20 +29,17 @@ class ResolverUserContext(UserContext):
return UserInfo(id=user_id)
async def get_authenticated_git_url(
self, repository: str, is_optional: bool = False
) -> str:
async def get_authenticated_git_url(self, repository: str) -> str:
# This would need to be implemented based on the git provider tokens
# For now, return a basic HTTPS URL
return f'https://github.com/{repository}.git'
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
# Return the appropriate token string from git_provider_tokens
# Return the appropriate token from git_provider_tokens
provider_tokens = await self.saas_user_auth.get_provider_tokens()
if provider_tokens:
provider_token = provider_tokens.get(provider_type)
if provider_token and provider_token.token:
return provider_token.token.get_secret_value()
return provider_tokens.get(provider_type)
return None
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:

View File

@@ -189,7 +189,6 @@ class SlackNewConversationView(SlackViewInterface):
'channel_id': self.channel_id,
'conversation_id': self.conversation_id,
'keycloak_user_id': user_info.keycloak_user_id,
'org_id': user_info.org_id,
'parent_id': self.thread_ts or self.message_ts,
'v1_enabled': v1_enabled,
},
@@ -198,7 +197,6 @@ class SlackNewConversationView(SlackViewInterface):
conversation_id=self.conversation_id,
channel_id=self.channel_id,
keycloak_user_id=user_info.keycloak_user_id,
org_id=user_info.org_id,
parent_id=self.thread_ts
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
v1_enabled=v1_enabled,
@@ -401,10 +399,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
if not agent_state or agent_state == AgentState.LOADING:
raise StartingConvoException('Conversation is still starting')
instructions, _ = self._get_instructions(jinja)
user_msg = MessageAction(content=instructions)
user_msg, _ = self._get_instructions(jinja)
user_msg_action = MessageAction(content=user_msg)
await conversation_manager.send_event_to_conversation(
self.conversation_id, event_to_dict(user_msg)
self.conversation_id, event_to_dict(user_msg_action)
)
async def send_message_to_v1_conversation(self, jinja: Environment):

View File

@@ -1,53 +0,0 @@
from storage.repository_store import RepositoryStore
from storage.stored_repository import StoredRepository
from storage.user_repo_map import UserRepositoryMap
from storage.user_repo_map_store import UserRepositoryMapStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import Repository
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
"""
Store repositories in DB and create user-repository mappings
Args:
repos: List of Repository objects to store
user_id: User ID associated with these repositories
"""
# Convert Repository objects to StoredRepository objects
# Convert Repository objects to UserRepositoryMap objects
stored_repos = []
user_repos = []
for repo in repos:
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
stored_repo = StoredRepository(
repo_name=repo.full_name,
repo_id=repo_id,
is_public=repo.is_public,
# Optional fields set to None by default
has_microagent=None,
has_setup_script=None,
)
stored_repos.append(stored_repo)
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
user_repos.append(user_repo_map)
# Get config instance
config = OpenHandsConfig()
try:
# Store repositories in the repos table
repo_store = RepositoryStore.get_instance(config)
repo_store.store_projects(stored_repos)
# Store user-repository mappings in the user-repos table
user_repo_store = UserRepositoryMapStore.get_instance(config)
user_repo_store.store_user_repo_mappings(user_repos)
logger.info(f'Saved repos for user {user_id}')
except Exception:
logger.warning('Failed to save repos', exc_info=True)

View File

@@ -1,24 +1,19 @@
from uuid import UUID
import stripe
from server.auth.token_manager import TokenManager
from server.constants import STRIPE_API_KEY
from server.logger import logger
from sqlalchemy.orm import Session
from storage.database import session_maker
from storage.org import Org
from storage.org_store import OrgStore
from storage.stripe_customer import StripeCustomer
from openhands.utils.async_utils import call_sync_from_async
stripe.api_key = STRIPE_API_KEY
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
async def find_customer_id_by_user_id(user_id: str) -> str | None:
# First search our own DB...
with session_maker() as session:
stripe_customer = (
session.query(StripeCustomer)
.filter(StripeCustomer.org_id == org_id)
.filter(StripeCustomer.keycloak_user_id == user_id)
.first()
)
if stripe_customer:
@@ -26,76 +21,46 @@ async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
# If that fails, fallback to stripe
search_result = await stripe.Customer.search_async(
query=f"metadata['org_id']:'{str(org_id)}'",
query=f"metadata['user_id']:'{user_id}'",
)
data = search_result.data
if not data:
logger.info(
'no_customer_for_org_id',
extra={'org_id': str(org_id)},
)
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
return None
return data[0].id # type: ignore [attr-defined]
async def find_customer_id_by_user_id(user_id: str) -> str | None:
# First search our own DB...
org = await call_sync_from_async(
OrgStore.get_current_org_from_keycloak_user_id, user_id
)
if not org:
logger.warning(f'Org not found for user {user_id}')
return None
customer_id = await find_customer_id_by_org_id(org.id)
return customer_id
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
# Get the current org for the user
org = await call_sync_from_async(
OrgStore.get_current_org_from_keycloak_user_id, user_id
)
if not org:
logger.warning(f'Org not found for user {user_id}')
return None
customer_id = await find_customer_id_by_org_id(org.id)
async def find_or_create_customer(user_id: str) -> str:
customer_id = await find_customer_id_by_user_id(user_id)
if customer_id:
return {'customer_id': customer_id, 'org_id': str(org.id)}
logger.info(
'creating_customer',
extra={'user_id': user_id, 'org_id': str(org.id)},
)
return customer_id
logger.info('creating_customer', extra={'user_id': user_id})
# Get the user info from keycloak
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
# Create the customer in stripe
customer = await stripe.Customer.create_async(
email=org.contact_email,
metadata={'org_id': str(org.id)},
email=str(user_info.get('email', '')),
metadata={'user_id': user_id},
)
# Save the stripe customer in the local db
with session_maker() as session:
session.add(
StripeCustomer(
keycloak_user_id=user_id,
org_id=org.id,
stripe_customer_id=customer.id,
)
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
)
session.commit()
logger.info(
'created_customer',
extra={
'user_id': user_id,
'org_id': str(org.id),
'stripe_customer_id': customer.id,
},
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
)
return {'customer_id': customer.id, 'org_id': str(org.id)}
return customer.id
async def has_payment_method_by_user_id(user_id: str) -> bool:
async def has_payment_method(user_id: str) -> bool:
customer_id = await find_customer_id_by_user_id(user_id)
if customer_id is None:
return False
@@ -106,28 +71,3 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
)
return bool(payment_methods.data)
async def migrate_customer(session: Session, user_id: str, org: Org):
stripe_customer = (
session.query(StripeCustomer)
.filter(StripeCustomer.keycloak_user_id == user_id)
.first()
)
if stripe_customer is None:
return
stripe_customer.org_id = org.id
customer = await stripe.Customer.modify_async(
id=stripe_customer.stripe_customer_id,
email=org.contact_email,
metadata={'user_id': '', 'org_id': str(org.id)},
)
logger.info(
'migrated_customer',
extra={
'user_id': user_id,
'org_id': str(org.id),
'stripe_customer_id': customer.id,
},
)

View File

@@ -6,9 +6,15 @@ import re
from typing import TYPE_CHECKING
from jinja2 import Environment, FileSystemLoader
from server.config import get_config
from server.constants import WEB_HOST
from storage.org_store import OrgStore
from storage.database import session_maker
from storage.repository_store import RepositoryStore
from storage.stored_repository import StoredRepository
from storage.user_repo_map import UserRepositoryMap
from storage.user_repo_map_store import UserRepositoryMapStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events import Event, EventSource
@@ -119,17 +125,26 @@ async def get_user_v1_enabled_setting(user_id: str | None) -> bool:
Returns:
True if V1 conversations are enabled for this user, False otherwise
"""
# If no user ID is provided, we can't check user settings
if not user_id:
return False
org = await call_sync_from_async(
OrgStore.get_current_org_from_keycloak_user_id, user_id
from storage.saas_settings_store import SaasSettingsStore
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
if not org or org.v1_enabled is None:
settings = await call_sync_from_async(
settings_store.get_user_settings_by_keycloak_id, user_id
)
if not settings or settings.v1_enabled is None:
return False
return org.v1_enabled
return settings.v1_enabled
def has_exact_mention(text: str, mention: str) -> bool:
@@ -394,46 +409,102 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
return message + footer
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
"""
Store repositories in DB and create user-repository mappings
Args:
repos: List of Repository objects to store
user_id: User ID associated with these repositories
"""
# Convert Repository objects to StoredRepository objects
# Convert Repository objects to UserRepositoryMap objects
stored_repos = []
user_repos = []
for repo in repos:
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
stored_repo = StoredRepository(
repo_name=repo.full_name,
repo_id=repo_id,
is_public=repo.is_public,
# Optional fields set to None by default
has_microagent=None,
has_setup_script=None,
)
stored_repos.append(stored_repo)
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
user_repos.append(user_repo_map)
# Get config instance
config = OpenHandsConfig()
try:
# Store repositories in the repos table
repo_store = RepositoryStore.get_instance(config)
repo_store.store_projects(stored_repos)
# Store user-repository mappings in the user-repos table
user_repo_store = UserRepositoryMapStore.get_instance(config)
user_repo_store.store_user_repo_mappings(user_repos)
logger.info(f'Saved repos for user {user_id}')
except Exception:
logger.warning('Failed to save repos', exc_info=True)
def infer_repo_from_message(user_msg: str) -> list[str]:
"""
Extract all repository names in the format 'owner/repo' from various Git provider URLs
and direct mentions in text. Supports GitHub, GitLab, and BitBucket.
Args:
user_msg: Input message that may contain repository references
Returns:
List of repository names in 'owner/repo' format, empty list if none found
"""
# Normalize the message by removing extra whitespace and newlines
normalized_msg = re.sub(r'\s+', ' ', user_msg.strip())
git_url_pattern = (
r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/'
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?'
r'(?:[/?#].*?)?(?=\s|$|[^\w.-])'
)
# Pattern to match Git URLs from GitHub, GitLab, and BitBucket
# Captures: protocol, domain, owner, repo (with optional .git extension)
git_url_pattern = r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?(?:[/?#].*?)?(?=\s|$|[^\w.-])'
# UPDATED: allow {{ owner/repo }} in addition to existing boundaries
# Pattern to match direct owner/repo mentions (e.g., "OpenHands/OpenHands")
# Must be surrounded by word boundaries or specific characters to avoid false positives
direct_pattern = (
r'(?:^|\s|{{|[\[\(\'":`])' # left boundary
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)'
r'(?=\s|$|}}|[\]\)\'",.:`])' # right boundary
r'(?:^|\s|[\[\(\'"])([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)(?=\s|$|[\]\)\'",.])'
)
matches: list[str] = []
matches = []
# Git URLs first (highest priority)
for owner, repo in re.findall(git_url_pattern, normalized_msg):
# First, find all Git URLs (highest priority)
git_matches = re.findall(git_url_pattern, normalized_msg)
for owner, repo in git_matches:
# Remove .git extension if present
repo = re.sub(r'\.git$', '', repo)
matches.append(f'{owner}/{repo}')
# Direct mentions
for owner, repo in re.findall(direct_pattern, normalized_msg):
# Second, find all direct owner/repo mentions
direct_matches = re.findall(direct_pattern, normalized_msg)
for owner, repo in direct_matches:
full_match = f'{owner}/{repo}'
# Skip if it looks like a version number, date, or file path
if (
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match)
or re.match(r'^\d{1,2}/\d{1,2}$', full_match)
or re.match(r'^[A-Z]/[A-Z]$', full_match)
or repo.endswith(('.txt', '.md', '.py', '.js'))
or ('.' in repo and len(repo.split('.')) > 2)
):
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match) # version numbers
or re.match(r'^\d{1,2}/\d{1,2}$', full_match) # dates
or re.match(r'^[A-Z]/[A-Z]$', full_match) # single letters
or repo.endswith('.txt')
or repo.endswith('.md') # file extensions
or repo.endswith('.py')
or repo.endswith('.js')
or '.' in repo
and len(repo.split('.')) > 2
): # complex file paths
continue
# Avoid duplicates from Git URLs already found
if full_match not in matches:
matches.append(full_match)

View File

@@ -20,8 +20,6 @@ down_revision = '059'
branch_labels = None
depends_on = None
# TODO: decide whether to modify this for orgs or users
def upgrade():
"""
@@ -30,10 +28,8 @@ def upgrade():
This replaces the functionality of the removed admin maintenance endpoint.
"""
# Hardcoded value to prevent migration failures when constant is removed from codebase
# This migration has already run in production, so we use the value that was current at the time
CURRENT_USER_SETTINGS_VERSION = 4
# Import here to avoid circular imports
from server.constants import CURRENT_USER_SETTINGS_VERSION
# Create a connection and bind it to a session
connection = op.get_bind()

View File

@@ -1,272 +0,0 @@
"""create org tables from pgerd schema
Revision ID: 089
Revises: 088
Create Date: 2025-01-07 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '089'
down_revision: Union[str, None] = '088'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Remove current settings table
op.execute('DROP TABLE IF EXISTS settings')
# Add already_migrated column to user_settings table
op.add_column(
'user_settings',
sa.Column(
'already_migrated',
sa.Boolean,
nullable=True,
server_default=sa.text('false'),
),
)
# Create role table
op.create_table(
'role',
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
sa.Column('name', sa.String, nullable=False),
sa.Column('rank', sa.Integer, nullable=False),
sa.UniqueConstraint('name', name='role_name_unique'),
)
# 1. Create default roles
op.execute(
sa.text("""
INSERT INTO role (name, rank) VALUES ('owner', 10), ('admin', 20), ('user', 1000)
ON CONFLICT (name) DO NOTHING;
""")
)
# Create org table with settings fields
op.create_table(
'org',
sa.Column(
'id',
postgresql.UUID(as_uuid=True),
primary_key=True,
),
sa.Column('name', sa.String, nullable=False),
sa.Column('contact_name', sa.String, nullable=True),
sa.Column('contact_email', sa.String, nullable=True),
sa.Column('conversation_expiration', sa.Integer, nullable=True),
# Settings fields moved to org table
sa.Column('agent', sa.String, nullable=True),
sa.Column('default_max_iterations', sa.Integer, nullable=True),
sa.Column('security_analyzer', sa.String, nullable=True),
sa.Column(
'confirmation_mode',
sa.Boolean,
nullable=True,
server_default=sa.text('false'),
),
sa.Column('default_llm_model', sa.String, nullable=True),
sa.Column('default_llm_base_url', sa.String, nullable=True),
sa.Column('remote_runtime_resource_factor', sa.Integer, nullable=True),
sa.Column(
'enable_default_condenser',
sa.Boolean,
nullable=False,
server_default=sa.text('true'),
),
sa.Column('billing_margin', sa.Float, nullable=True),
sa.Column(
'enable_proactive_conversation_starters',
sa.Boolean,
nullable=False,
server_default=sa.text('true'),
),
sa.Column('sandbox_base_container_image', sa.String, nullable=True),
sa.Column('sandbox_runtime_container_image', sa.String, nullable=True),
sa.Column(
'org_version', sa.Integer, nullable=False, server_default=sa.text('0')
),
sa.Column('mcp_config', sa.JSON, nullable=True),
sa.Column('_search_api_key', sa.String, nullable=True),
sa.Column('_sandbox_api_key', sa.String, nullable=True),
sa.Column('max_budget_per_task', sa.Float, nullable=True),
sa.Column(
'enable_solvability_analysis',
sa.Boolean,
nullable=True,
server_default=sa.text('false'),
),
sa.Column('v1_enabled', sa.Boolean, nullable=True),
sa.Column('condenser_max_size', sa.Integer, nullable=True),
sa.UniqueConstraint('name', name='org_name_unique'),
)
# Create user table with user-specific settings fields
op.create_table(
'user',
sa.Column(
'id',
postgresql.UUID(as_uuid=True),
primary_key=True,
),
sa.Column('current_org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('role_id', sa.Integer, nullable=True),
sa.Column('accepted_tos', sa.DateTime, nullable=True),
sa.Column(
'enable_sound_notifications',
sa.Boolean,
nullable=True,
server_default=sa.text('false'),
),
sa.Column('language', sa.String, nullable=True),
sa.Column('user_consents_to_analytics', sa.Boolean, nullable=True),
sa.Column('email', sa.String, nullable=True),
sa.Column('email_verified', sa.Boolean, nullable=True),
sa.ForeignKeyConstraint(
['current_org_id'], ['org.id'], name='current_org_fkey'
),
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='user_role_fkey'),
)
# Create org_member table (junction table for many-to-many relationship)
op.create_table(
'org_member',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('role_id', sa.Integer, nullable=False),
sa.Column('_llm_api_key', sa.String, nullable=False),
sa.Column('max_iterations', sa.Integer, nullable=True),
sa.Column('llm_model', sa.String, nullable=True),
sa.Column('_llm_api_key_for_byor', sa.String, nullable=True),
sa.Column('llm_base_url', sa.String, nullable=True),
sa.Column('status', sa.String, nullable=True),
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='om_org_fkey'),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='om_user_fkey'),
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='om_role_fkey'),
sa.PrimaryKeyConstraint('org_id', 'user_id'),
)
# Add org_id column to existing tables
# billing_sessions
op.add_column(
'billing_sessions',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
)
op.create_foreign_key(
'billing_sessions_org_fkey', 'billing_sessions', 'org', ['org_id'], ['id']
)
# Create conversation_metadata_saas table
op.create_table(
'conversation_metadata_saas',
sa.Column('conversation_id', sa.String(), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.ForeignKeyConstraint(
['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'
),
sa.ForeignKeyConstraint(
['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'
),
sa.PrimaryKeyConstraint('conversation_id'),
)
# custom_secrets
op.add_column(
'custom_secrets',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
)
op.create_foreign_key(
'custom_secrets_org_fkey', 'custom_secrets', 'org', ['org_id'], ['id']
)
# api_keys
op.add_column(
'api_keys', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
)
op.create_foreign_key('api_keys_org_fkey', 'api_keys', 'org', ['org_id'], ['id'])
# slack_conversation
op.add_column(
'slack_conversation',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
)
op.create_foreign_key(
'slack_conversation_org_fkey', 'slack_conversation', 'org', ['org_id'], ['id']
)
# slack_users
op.add_column(
'slack_users', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
)
op.create_foreign_key(
'slack_users_org_fkey', 'slack_users', 'org', ['org_id'], ['id']
)
# stripe_customers
op.alter_column(
'stripe_customers',
'keycloak_user_id',
existing_type=sa.String(),
nullable=True,
)
op.add_column(
'stripe_customers',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
)
op.create_foreign_key(
'stripe_customers_org_fkey', 'stripe_customers', 'org', ['org_id'], ['id']
)
def downgrade() -> None:
# Drop already_migrated column from user_settings table
op.drop_column('user_settings', 'already_migrated')
# Drop foreign keys and columns added to existing tables
op.drop_constraint(
'stripe_customers_org_fkey', 'stripe_customers', type_='foreignkey'
)
op.drop_column('stripe_customers', 'org_id')
op.alter_column(
'stripe_customers',
'keycloak_user_id',
existing_type=sa.String(),
nullable=False,
)
op.drop_constraint('slack_users_org_fkey', 'slack_users', type_='foreignkey')
op.drop_column('slack_users', 'org_id')
op.drop_constraint(
'slack_conversation_org_fkey', 'slack_conversation', type_='foreignkey'
)
op.drop_column('slack_conversation', 'org_id')
op.drop_constraint('api_keys_org_fkey', 'api_keys', type_='foreignkey')
op.drop_column('api_keys', 'org_id')
op.drop_constraint('custom_secrets_org_fkey', 'custom_secrets', type_='foreignkey')
op.drop_column('custom_secrets', 'org_id')
# Drop conversation_metadata_saas table
op.drop_table('conversation_metadata_saas')
op.drop_constraint(
'billing_sessions_org_fkey', 'billing_sessions', type_='foreignkey'
)
op.drop_column('billing_sessions', 'org_id')
# Drop tables in reverse order due to foreign key constraints
op.drop_table('org_member')
op.drop_table('user')
op.drop_table('org')
op.drop_table('role')

View File

@@ -1,28 +0,0 @@
"""Add git_user_name and git_user_email columns to user table.
Revision ID: 090
Revises: 089
Create Date: 2025-01-22
"""
import sqlalchemy as sa
from alembic import op
revision = '090'
down_revision = '089'
def upgrade() -> None:
op.add_column(
'user',
sa.Column('git_user_name', sa.String, nullable=True),
)
op.add_column(
'user',
sa.Column('git_user_email', sa.String, nullable=True),
)
def downgrade() -> None:
op.drop_column('user', 'git_user_email')
op.drop_column('user', 'git_user_name')

11661
enterprise/poetry.lock generated

File diff suppressed because one or more lines are too long

View File

@@ -4,10 +4,6 @@ from dotenv import load_dotenv
load_dotenv()
# Ensure SAAS configuration is used
if not os.getenv('OPENHANDS_CONFIG_CLS'):
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
import socketio # noqa: E402
from fastapi import Request, status # noqa: E402
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
@@ -38,7 +34,6 @@ from server.routes.integration.linear import linear_integration_router # noqa:
from server.routes.integration.slack import slack_router # noqa: E402
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.orgs import org_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
from server.sharing.shared_conversation_router import ( # noqa: E402
@@ -91,7 +86,6 @@ if GITLAB_APP_CLIENT_ID:
base_app.include_router(gitlab_integration_router)
base_app.include_router(api_keys_router) # Add routes for API key management
base_app.include_router(org_router) # Add routes for organization management
add_github_proxy_routes(base_app)
add_debugging_routes(
base_app

View File

@@ -39,8 +39,6 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
'on',
)
DUPLICATE_EMAIL_CHECK = os.getenv('DUPLICATE_EMAIL_CHECK', 'true') in ('1', 'true')
# reCAPTCHA Enterprise
RECAPTCHA_PROJECT_ID = os.getenv('RECAPTCHA_PROJECT_ID', '').strip()
RECAPTCHA_SITE_KEY = os.getenv('RECAPTCHA_SITE_KEY', '').strip()

View File

@@ -77,15 +77,6 @@ class SaasUserAuth(UserAuth):
self.access_token = SecretStr(tokens['access_token'])
self.refresh_token = SecretStr(tokens['refresh_token'])
self.refreshed = True
if not self.email or not self.email_verified or not self.user_id:
# We don't need to verify the signature here because we just refreshed
# this token from the IDP via token_manager.refresh()
access_token_payload = jwt.decode(
tokens['access_token'], options={'verify_signature': False}
)
self.user_id = access_token_payload['sub']
self.email = access_token_payload['email']
self.email_verified = access_token_payload['email_verified']
def _is_token_expired(self, token: SecretStr):
logger.debug('saas_user_auth_is_token_expired')
@@ -112,6 +103,7 @@ class SaasUserAuth(UserAuth):
return settings
settings_store = await self.get_user_settings_store()
settings = await settings_store.load()
# If load() returned None, should settings be created?
if settings:
settings.email = self.email
settings.email_verified = self.email_verified
@@ -282,13 +274,11 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
if not user_id:
return None
offline_token = await token_manager.load_offline_token(user_id)
saas_user_auth = SaasUserAuth(
return SaasUserAuth(
user_id=user_id,
refresh_token=SecretStr(offline_token),
auth_type=AuthType.BEARER,
)
await saas_user_auth.refresh()
return saas_user_auth
except Exception as exc:
raise BearerTokenError from exc

View File

@@ -16,11 +16,9 @@ from keycloak.exceptions import (
KeycloakError,
KeycloakPostError,
)
from server.auth.auth_error import ExpiredError
from server.auth.constants import (
BITBUCKET_APP_CLIENT_ID,
BITBUCKET_APP_CLIENT_SECRET,
DUPLICATE_EMAIL_CHECK,
GITHUB_APP_CLIENT_ID,
GITHUB_APP_CLIENT_SECRET,
GITLAB_APP_CLIENT_ID,
@@ -427,8 +425,6 @@ class TokenManager:
access_token = data.get('access_token')
refresh_token = data.get('refresh_token')
if not access_token or not refresh_token:
if data.get('error') == 'bad_refresh_token':
raise ExpiredError()
raise ValueError(
'Failed to refresh token: missing access_token or refresh_token in response.'
)
@@ -650,10 +646,6 @@ class TokenManager:
if not email:
return False
# We have the option to skip the duplicate email check in test environments
if not DUPLICATE_EMAIL_CHECK:
return False
base_email = extract_base_email(email)
if not base_email:
logger.warning(f'Could not extract base email from: {email}')

View File

@@ -8,7 +8,7 @@ import socketio
from server.logger import logger
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
from storage.database import session_maker
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.stored_conversation_metadata import StoredConversationMetadata
from openhands.core.config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -524,18 +524,16 @@ class ClusteredConversationManager(StandaloneConversationManager):
)
# Look up the user_id from the database
with session_maker() as session:
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
conversation_metadata = (
session.query(StoredConversationMetadata)
.filter(
StoredConversationMetadataSaas.conversation_id
StoredConversationMetadata.conversation_id
== conversation_id
)
.first()
)
user_id = (
str(conversation_metadata_saas.user_id)
if conversation_metadata_saas
else None
conversation_metadata.user_id if conversation_metadata else None
)
# Handle the stopped conversation asynchronously
asyncio.create_task(

View File

@@ -19,8 +19,8 @@ IS_LOCAL_ENV = bool(HOST == 'localhost')
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
# Map of user settings versions to their corresponding default LLM models
# This ensures that PERSONAL_WORKSPACE_VERSION_TO_MODEL and LITELLM_DEFAULT_MODEL stay in sync
PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
# This ensures that CURRENT_USER_SETTINGS_VERSION and LITELLM_DEFAULT_MODEL stay in sync
USER_SETTINGS_VERSION_TO_MODEL = {
1: 'claude-3-5-sonnet-20241022',
2: 'claude-3-7-sonnet-20250219',
3: 'claude-sonnet-4-20250514',
@@ -31,8 +31,7 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
# Current user settings version - this should be the latest key in USER_SETTINGS_VERSION_TO_MODEL
ORG_SETTINGS_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
PERSONAL_WORKSPACE_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
CURRENT_USER_SETTINGS_VERSION = max(USER_SETTINGS_VERSION_TO_MODEL.keys())
LITE_LLM_API_URL = os.environ.get(
'LITE_LLM_API_URL', 'https://llm-proxy.app.all-hands.dev'
@@ -56,6 +55,7 @@ SUBSCRIPTION_PRICE_DATA = {
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
STRIPE_WEBHOOK_SECRET = os.environ.get('STRIPE_WEBHOOK_SECRET', None)
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
SLACK_CLIENT_ID = os.environ.get('SLACK_CLIENT_ID', None)
@@ -91,5 +91,5 @@ def get_default_litellm_model():
"""Construct proxy for litellm model based on user settings if not set explicitly."""
if LITELLM_DEFAULT_MODEL:
return LITELLM_DEFAULT_MODEL
model = PERSONAL_WORKSPACE_VERSION_TO_MODEL[PERSONAL_WORKSPACE_VERSION]
model = USER_SETTINGS_VERSION_TO_MODEL[CURRENT_USER_SETTINGS_VERSION]
return build_litellm_proxy_model_path(model)

View File

@@ -1,68 +0,0 @@
"""
Email domain validation utilities for enterprise endpoints.
"""
from fastapi import Depends, HTTPException, Request, status
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_auth, get_user_id
async def get_admin_user_id(
request: Request, user_id: str | None = Depends(get_user_id)
) -> str:
"""
Dependency that validates user has @openhands.dev email domain.
This dependency can be used in place of get_user_id for endpoints that
should only be accessible to admin users. Currently, this is implemented
by checking for @openhands.dev email domain.
TODO: In the future, this should be replaced with an explicit is_admin flag
in user/org settings instead of relying on email domain validation.
Args:
request: FastAPI request object
user_id: User ID from get_user_id dependency
Returns:
str: User ID if email domain is valid
Raises:
HTTPException: 403 if email domain is not @openhands.dev
HTTPException: 401 if user is not authenticated
Example:
@router.post('/endpoint')
async def create_resource(
user_id: str = Depends(get_admin_user_id),
):
# Only admin users can access this endpoint
pass
"""
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='User not authenticated',
)
user_auth = await get_user_auth(request)
user_email = await user_auth.get_user_email()
if not user_email:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='User email not available',
)
if not user_email.endswith('@openhands.dev'):
logger.warning(
'Access denied - invalid email domain',
extra={'user_id': user_id, 'email_domain': user_email.split('@')[-1]},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='Access restricted to @openhands.dev users',
)
return user_id

View File

@@ -44,13 +44,11 @@ class MyProcessor(MaintenanceTaskProcessor):
### UserVersionUpgradeProcessor
Located in `user_version_upgrade_processor.py`, this processor:
- Handles up to 100 user IDs per task
- Upgrades users with `user_version < ORG_SETTINGS_VERSION`
- Upgrades users with `user_version < CURRENT_USER_SETTINGS_VERSION`
- Uses `SaasSettingsStore.create_default_settings()` for upgrades
**Usage:**
```python
from server.maintenance_task_processor.user_version_upgrade_processor import UserVersionUpgradeProcessor
@@ -146,26 +144,22 @@ task = create_maintenance_task(
## Best Practices
### Processor Design
- Keep tasks short-running (under 1 minute)
- Handle errors gracefully and return meaningful error information
- Use batch processing for large datasets
- Include progress information in the return dict
### Error Handling
- Always wrap your processor logic in try-catch blocks
- Return structured error information
- Log important events for debugging
### Performance
- Limit batch sizes to avoid long-running tasks
- Use database sessions efficiently
- Consider memory usage for large datasets
### Testing
- Create unit tests for your processors
- Test error conditions
- Verify the processor serialization/deserialization works correctly
@@ -173,7 +167,6 @@ task = create_maintenance_task(
## Database Patterns
The maintenance task system follows the repository's established patterns:
- Uses `session_maker()` for database operations
- Wraps sync database operations in `call_sync_from_async` for async routes
- Follows proper SQLAlchemy query patterns
@@ -181,18 +174,15 @@ The maintenance task system follows the repository's established patterns:
## Integration with Existing Systems
### User Management
- Integrates with the existing `UserSettings` model
- Uses the current user versioning system (`ORG_SETTINGS_VERSION`)
- Uses the current user versioning system (`CURRENT_USER_SETTINGS_VERSION`)
- Maintains compatibility with existing user management workflows
### Authentication
- Admin endpoints use the existing SaaS authentication system
- Requires users to have `admin = True` in their UserSettings
### Monitoring
- Tasks are logged with structured information
- Status updates are tracked in the database
- Error information is preserved for debugging
@@ -216,7 +206,6 @@ The maintenance task system follows the repository's established patterns:
## Future Enhancements
Potential improvements that could be added:
- Task dependencies and scheduling
- Retry mechanisms for failed tasks
- Real-time progress updates

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
from typing import List
from server.constants import CURRENT_USER_SETTINGS_VERSION
from server.logger import logger
from storage.database import session_maker
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskProcessor
from storage.saas_settings_store import SaasSettingsStore
from storage.user_settings import UserSettings
from openhands.core.config import load_openhands_config
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
"""
Processor for upgrading user settings to the current version.
This processor takes a list of user IDs and upgrades any users
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
"""
user_ids: List[str]
async def __call__(self, task: MaintenanceTask) -> dict:
"""
Process user version upgrades for the specified user IDs.
Args:
task: The maintenance task being processed
Returns:
dict: Results containing successful and failed user IDs
"""
logger.info(
'user_version_upgrade_processor:start',
extra={
'task_id': task.id,
'user_count': len(self.user_ids),
'current_version': CURRENT_USER_SETTINGS_VERSION,
},
)
if len(self.user_ids) > 100:
raise ValueError(
f'Too many user IDs: {len(self.user_ids)}. Maximum is 100.'
)
config = load_openhands_config()
# Track results
successful_upgrades = []
failed_upgrades = []
users_already_current = []
# Find users that need upgrading
with session_maker() as session:
users_to_upgrade = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id.in_(self.user_ids),
UserSettings.user_version < CURRENT_USER_SETTINGS_VERSION,
)
.all()
)
# Track users that are already current
users_needing_upgrade_ids = {u.keycloak_user_id for u in users_to_upgrade}
users_already_current = [
uid for uid in self.user_ids if uid not in users_needing_upgrade_ids
]
logger.info(
'user_version_upgrade_processor:found_users',
extra={
'task_id': task.id,
'users_to_upgrade': len(users_to_upgrade),
'users_already_current': len(users_already_current),
'total_requested': len(self.user_ids),
},
)
# Process each user that needs upgrading
for user_settings in users_to_upgrade:
user_id = user_settings.keycloak_user_id
old_version = user_settings.user_version
try:
logger.info(
'user_version_upgrade_processor:upgrading_user',
extra={
'task_id': task.id,
'user_id': user_id,
'old_version': old_version,
'new_version': CURRENT_USER_SETTINGS_VERSION,
},
)
# Create SaasSettingsStore instance and upgrade
settings_store = await SaasSettingsStore.get_instance(config, user_id)
await settings_store.create_default_settings(user_settings)
successful_upgrades.append(
{
'user_id': user_id,
'old_version': old_version,
'new_version': CURRENT_USER_SETTINGS_VERSION,
}
)
logger.info(
'user_version_upgrade_processor:user_upgraded',
extra={
'task_id': task.id,
'user_id': user_id,
'old_version': old_version,
'new_version': CURRENT_USER_SETTINGS_VERSION,
},
)
except Exception as e:
failed_upgrades.append(
{'user_id': user_id, 'old_version': old_version, 'error': str(e)}
)
logger.error(
'user_version_upgrade_processor:user_upgrade_failed',
extra={
'task_id': task.id,
'user_id': user_id,
'old_version': old_version,
'error': str(e),
},
)
# Create result summary
result = {
'total_users': len(self.user_ids),
'users_already_current': users_already_current,
'successful_upgrades': successful_upgrades,
'failed_upgrades': failed_upgrades,
'summary': (
f'Processed {len(self.user_ids)} users: '
f'{len(successful_upgrades)} upgraded, '
f'{len(users_already_current)} already current, '
f'{len(failed_upgrades)} errors'
),
}
logger.info(
'user_version_upgrade_processor:completed',
extra={'task_id': task.id, 'result': result},
)
return result

View File

@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING
from storage.api_key_store import ApiKeyStore
if TYPE_CHECKING:
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -34,7 +36,6 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
Returns:
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
"""
from storage.api_key_store import ApiKeyStore
api_key_store = ApiKeyStore.get_instance()
if user_id:

View File

@@ -144,7 +144,7 @@ class SetAuthCookieMiddleware:
# "if accepted_tos is not None" as there should not be any users with
# accepted_tos equal to "None"
if accepted_tos is False and request.url.path != '/api/accept_tos':
logger.warning('User has not accepted the terms of service')
logger.error('User has not accepted the terms of service')
raise TosNotAcceptedError
def _should_attach(self, request: Request) -> bool:
@@ -162,8 +162,6 @@ class SetAuthCookieMiddleware:
'/api/email/resend',
'/oauth/device/authorize',
'/oauth/device/token',
'/api/v1/web-client/config',
'/api/v1/webhooks/secrets',
)
if path in ignore_paths:
return False

View File

@@ -1,95 +1,113 @@
from datetime import UTC, datetime
import httpx
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, field_validator
from server.config import get_config
from server.constants import (
BYOR_KEY_VERIFICATION_TIMEOUT,
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
)
from storage.api_key_store import ApiKeyStore
from storage.lite_llm_manager import LiteLlmManager
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.user_store import UserStore
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.http_session import httpx_verify_option
# Helper functions for BYOR API key management
async def get_byor_key_from_db(user_id: str) -> str | None:
"""Get the BYOR key from the database for a user."""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
def _get_byor_key():
user = UserStore.get_user_by_id(user_id)
if not user:
return None
current_org_id = user.current_org_id
current_org_member: OrgMember = None
for org_member in user.org_members:
if org_member.org_id == current_org_id:
current_org_member = org_member
break
if not current_org_member:
return None
if current_org_member.llm_api_key_for_byor:
return current_org_member.llm_api_key_for_byor.get_secret_value()
return None
return await call_sync_from_async(_get_byor_key)
user_db_settings = await call_sync_from_async(
settings_store.get_user_settings_by_keycloak_id, user_id
)
if user_db_settings and user_db_settings.llm_api_key_for_byor:
return user_db_settings.llm_api_key_for_byor
return None
async def store_byor_key_in_db(user_id: str, key: str) -> None:
"""Store the BYOR key in the database for a user."""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
def _update_user_settings():
user = UserStore.get_user_by_id(user_id)
if not user:
return None
current_org_id = user.current_org_id
current_org_member: OrgMember = None
for org_member in user.org_members:
if org_member.org_id == current_org_id:
current_org_member = org_member
break
if not current_org_member:
return None
current_org_member.llm_api_key_for_byor = key
OrgMemberStore.update_org_member(current_org_member)
with session_maker() as session:
user_db_settings = settings_store.get_user_settings_by_keycloak_id(
user_id, session
)
if user_db_settings:
user_db_settings.llm_api_key_for_byor = key
session.commit()
logger.info(
'Successfully stored BYOR key in user settings',
extra={'user_id': user_id},
)
else:
logger.warning(
'User settings not found when trying to store BYOR key',
extra={'user_id': user_id},
)
await call_sync_from_async(_update_user_settings)
async def generate_byor_key(user_id: str) -> str | None:
"""Generate a new BYOR key for a user."""
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
logger.warning(
'LiteLLM API configuration not found', extra={'user_id': user_id}
)
return None
try:
user = await UserStore.get_user_by_id_async(user_id)
if not user:
return None
current_org_id = str(user.current_org_id)
key = await LiteLlmManager.generate_key(
user_id,
current_org_id,
f'BYOR Key - user {user_id}, org {current_org_id}',
{'type': 'byor'},
)
if key:
logger.info(
'Successfully generated new BYOR key',
extra={
async with httpx.AsyncClient(
verify=httpx_verify_option(),
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
) as client:
response = await client.post(
f'{LITE_LLM_API_URL}/key/generate',
json={
'user_id': user_id,
'key_length': len(key) if key else 0,
'key_prefix': key[:10] + '...' if key and len(key) > 10 else key,
'metadata': {'type': 'byor'},
'key_alias': f'BYOR Key - user {user_id}',
},
)
return key
else:
logger.error(
'Failed to generate BYOR LLM API key - no key in response',
extra={'user_id': user_id},
)
return None
response.raise_for_status()
response_json = response.json()
key = response_json.get('key')
if key:
logger.info(
'Successfully generated new BYOR key',
extra={
'user_id': user_id,
'key_length': len(key) if key else 0,
'key_prefix': key[:10] + '...'
if key and len(key) > 10
else key,
},
)
return key
else:
logger.error(
'Failed to generate BYOR LLM API key - no key in response',
extra={'user_id': user_id, 'response_json': response_json},
)
return None
except Exception as e:
logger.exception(
'Error generating BYOR key',
@@ -98,25 +116,96 @@ async def generate_byor_key(user_id: str) -> str | None:
return None
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
"""Delete the BYOR key from LiteLLM using the key directly.
async def verify_byor_key_in_litellm(byor_key: str, user_id: str) -> bool:
"""Verify that a BYOR key is valid in LiteLLM by making a lightweight API call.
Also attempts to delete by key alias if the key is not found,
to clean up orphaned aliases that could block key regeneration.
Args:
byor_key: The BYOR key to verify
user_id: The user ID for logging purposes
Returns:
True if the key is verified as valid, False if verification fails or key is invalid.
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
"""
try:
# Get user to construct the key alias
user = await UserStore.get_user_by_id_async(user_id)
key_alias = None
if user and user.current_org_id:
key_alias = f'BYOR Key - user {user_id}, org {user.current_org_id}'
if not (LITE_LLM_API_URL and byor_key):
return False
await LiteLlmManager.delete_key(byor_key, key_alias=key_alias)
logger.info(
'Successfully deleted BYOR key from LiteLLM',
extra={'user_id': user_id},
try:
async with httpx.AsyncClient(
verify=httpx_verify_option(),
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
) as client:
# Make a lightweight request to verify the key
# Using /v1/models endpoint as it's lightweight and requires authentication
response = await client.get(
f'{LITE_LLM_API_URL}/v1/models',
headers={
'Authorization': f'Bearer {byor_key}',
},
)
# Only 200 status code indicates valid key
if response.status_code == 200:
logger.debug(
'BYOR key verification successful',
extra={'user_id': user_id},
)
return True
# All other status codes (401, 403, 500, etc.) are treated as invalid
# This includes authentication errors and server errors
logger.warning(
'BYOR key verification failed - treating as invalid',
extra={
'user_id': user_id,
'status_code': response.status_code,
'key_prefix': byor_key[:10] + '...'
if len(byor_key) > 10
else byor_key,
},
)
return False
except (httpx.TimeoutException, Exception) as e:
# Any exception (timeout, network error, etc.) means we can't verify
# Return False to trigger regeneration rather than returning potentially invalid key
logger.warning(
'BYOR key verification error - treating as invalid to ensure key validity',
extra={
'user_id': user_id,
'error': str(e),
'error_type': type(e).__name__,
},
)
return True
return False
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
"""Delete the BYOR key from LiteLLM using the key directly."""
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
logger.warning(
'LiteLLM API configuration not found', extra={'user_id': user_id}
)
return False
try:
async with httpx.AsyncClient(
verify=httpx_verify_option(),
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
) as client:
# Delete the key directly using the key value
delete_url = f'{LITE_LLM_API_URL}/key/delete'
delete_payload = {'keys': [byor_key]}
delete_response = await client.post(delete_url, json=delete_payload)
delete_response.raise_for_status()
logger.info(
'Successfully deleted BYOR key from LiteLLM',
extra={'user_id': user_id},
)
return True
except Exception as e:
logger.exception(
'Error deleting BYOR key from LiteLLM',
@@ -268,7 +357,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
byor_key = await get_byor_key_from_db(user_id)
if byor_key:
# Validate that the key is actually registered in LiteLLM
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
is_valid = await verify_byor_key_in_litellm(byor_key, user_id)
if is_valid:
return {'key': byor_key}
else:
@@ -323,6 +412,15 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
try:
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
logger.warning(
'LiteLLM API configuration not found', extra={'user_id': user_id}
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='LiteLLM API configuration not found',
)
# Get the existing BYOR key from the database
existing_byor_key = await get_byor_key_from_db(user_id)

View File

@@ -1,6 +1,5 @@
import base64
import json
import uuid
import warnings
from datetime import datetime, timezone
from typing import Annotated, Literal, Optional
@@ -23,12 +22,12 @@ from server.auth.gitlab_sync import schedule_gitlab_repo_sync
from server.auth.recaptcha_service import recaptcha_service
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
from server.config import sign_token
from server.config import get_config, sign_token
from server.constants import IS_FEATURE_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
from storage.database import session_maker
from storage.user import User
from storage.user_store import UserStore
from storage.saas_settings_store import SaasSettingsStore
from storage.user_settings import UserSettings
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
@@ -88,8 +87,7 @@ def get_cookie_domain(request: Request) -> str | None:
# for now just use the full hostname except for staging stacks.
return (
None
if not request.url.hostname
or request.url.hostname.endswith('staging.all-hands.dev')
if (request.url.hostname or '').endswith('staging.all-hand.dev')
else request.url.hostname
)
@@ -176,20 +174,6 @@ async def keycloak_callback(
email = user_info.get('email')
user_id = user_info['sub']
user = await UserStore.get_user_by_id_async(user_id)
if not user:
user = await UserStore.create_user(user_id, user_info)
if not user:
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': f'Failed to authenticate user {user_info["preferred_username"]}'
},
)
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
# reCAPTCHA verification with Account Defender
if RECAPTCHA_SITE_KEY:
@@ -376,7 +360,15 @@ async def keycloak_callback(
f'&state={state}'
)
has_accepted_tos = user.accepted_tos is not None
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
user_settings = settings_store.get_user_settings_by_keycloak_id(user_id)
has_accepted_tos = (
user_settings is not None and user_settings.accepted_tos is not None
)
# If the user hasn't accepted the TOS, redirect to the TOS page
if not has_accepted_tos:
encoded_redirect_url = quote(redirect_url, safe='')
@@ -494,20 +486,28 @@ async def accept_tos(request: Request):
redirect_url = body.get('redirect_url', str(request.base_url))
# Update user settings with TOS acceptance
accepted_tos: datetime = datetime.now(timezone.utc)
with session_maker() as session:
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
if not user:
session.rollback()
logger.error('User for {user_id} not found.')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'User does not exist'},
user_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == user_id)
.first()
)
if user_settings:
user_settings.accepted_tos = datetime.now(timezone.utc)
session.merge(user_settings)
else:
# Create user settings if they don't exist
user_settings = UserSettings(
keycloak_user_id=user_id,
accepted_tos=datetime.now(timezone.utc),
user_version=0, # This will trigger a migration to the latest version on next load
)
user.accepted_tos = accepted_tos
session.add(user_settings)
session.commit()
logger.info(f'User {user_id} accepted TOS')
logger.info(f'User {user_id} accepted TOS')
response = JSONResponse(
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}

View File

@@ -4,42 +4,63 @@ from datetime import UTC, datetime
from decimal import Decimal
from enum import Enum
import httpx
import stripe
from dateutil.relativedelta import relativedelta # type: ignore
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from fastapi.responses import JSONResponse, RedirectResponse
from integrations import stripe_service
from pydantic import BaseModel
from server.config import get_config
from server.constants import (
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
STRIPE_API_KEY,
STRIPE_WEBHOOK_SECRET,
SUBSCRIPTION_PRICE_DATA,
get_default_litellm_model,
)
from server.logger import logger
from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.saas_settings_store import SaasSettingsStore
from storage.subscription_access import SubscriptionAccess
from storage.user_store import UserStore
from openhands.app_server.config import get_global_config
from openhands.server.user_auth import get_user_id
from openhands.utils.http_session import httpx_verify_option
stripe.api_key = STRIPE_API_KEY
billing_router = APIRouter(prefix='/api/billing')
async def validate_billing_enabled() -> None:
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
def is_all_hands_saas_environment(request: Request) -> bool:
"""Check if the current domain is an All Hands SaaS environment.
Args:
request: FastAPI Request object
Returns:
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
"""
Validate that the billing feature flag is enabled
hostname = request.url.hostname or ''
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
def validate_saas_environment(request: Request) -> None:
"""Validate that the request is coming from an All Hands SaaS environment.
Args:
request: FastAPI Request object
Raises:
HTTPException: If the request is not from an All Hands SaaS environment
"""
config = get_global_config()
web_client_config = await config.web_client.get_web_client_config()
if not web_client_config.feature_flags.enable_billing:
if not is_all_hands_saas_environment(request):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=(
'Billing is disabled in this environment. '
'Please set OH_WEB_CLIENT_FEATURE_FLAGS_ENABLE_BILLING to enable billing.'
),
detail='Checkout sessions are only available for All Hands SaaS environments',
)
@@ -85,20 +106,31 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
return max(max_budget - spend, 0.0)
# Endpoint to retrieve the current organization's credit balance
# Endpoint to retrieve user's current credit balance
@billing_router.get('/credits')
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
if not stripe_service.STRIPE_API_KEY:
return GetCreditsResponse()
user = await UserStore.get_user_by_id_async(user_id)
user_team_info = await LiteLlmManager.get_user_team_info(
user_id, str(user.current_org_id)
)
# Update to use calculate_credits
spend = user_team_info.get('spend', 0)
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
credits = max(max_budget - spend, 0)
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
try:
async with httpx.AsyncClient(
verify=httpx_verify_option(), timeout=15.0
) as client:
user_json = await _get_litellm_user(client, user_id)
credits = calculate_credits(user_json['user_info'])
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
except httpx.HTTPStatusError as e:
logger.error(
f'litellm_get_user_failed: {type(e).__name__}: {e}',
extra={
'user_id': user_id,
'status_code': e.response.status_code,
},
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve credit balance from billing service',
)
# Endpoint to retrieve user's current subscription access
@@ -133,7 +165,79 @@ async def get_subscription_access(
async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
if not user_id:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
return await stripe_service.has_payment_method_by_user_id(user_id)
return await stripe_service.has_payment_method(user_id)
# Endpoint to cancel user's subscription
@billing_router.post('/cancel-subscription')
async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONResponse:
"""Cancel user's active subscription at the end of the current billing period."""
if not user_id:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
with session_maker() as session:
# Find the user's active subscription
now = datetime.now(UTC)
subscription_access = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.status == 'ACTIVE')
.filter(SubscriptionAccess.user_id == user_id)
.filter(SubscriptionAccess.start_at <= now)
.filter(SubscriptionAccess.end_at >= now)
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not already cancelled
.first()
)
if not subscription_access:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='No active subscription found',
)
if not subscription_access.stripe_subscription_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Cannot cancel subscription: missing Stripe subscription ID',
)
try:
# Cancel the subscription in Stripe at period end
await stripe.Subscription.modify_async(
subscription_access.stripe_subscription_id, cancel_at_period_end=True
)
# Update local database
subscription_access.cancelled_at = datetime.now(UTC)
session.merge(subscription_access)
session.commit()
logger.info(
'subscription_cancelled',
extra={
'user_id': user_id,
'stripe_subscription_id': subscription_access.stripe_subscription_id,
'subscription_access_id': subscription_access.id,
'end_at': subscription_access.end_at,
},
)
return JSONResponse(
{'status': 'success', 'message': 'Subscription cancelled successfully'}
)
except stripe.StripeError as e:
logger.error(
'stripe_cancellation_failed',
extra={
'user_id': user_id,
'stripe_subscription_id': subscription_access.stripe_subscription_id,
'error': str(e),
},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f'Failed to cancel subscription: {str(e)}',
)
# Endpoint to create a new setup intent in stripe
@@ -141,17 +245,17 @@ async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
async def create_customer_setup_session(
request: Request, user_id: str = Depends(get_user_id)
) -> CreateBillingSessionResponse:
await validate_billing_enabled()
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
base_url = _get_base_url(request)
validate_saas_environment(request)
customer_id = await stripe_service.find_or_create_customer(user_id)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_info['customer_id'],
customer=customer_id,
mode='setup',
payment_method_types=['card'],
success_url=f'{base_url}?free_credits=success',
cancel_url=f'{base_url}',
success_url=f'{request.base_url}?free_credits=success',
cancel_url=f'{request.base_url}',
)
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
# Endpoint to create a new Stripe checkout session for credit purchase
@@ -161,11 +265,11 @@ async def create_checkout_session(
request: Request,
user_id: str = Depends(get_user_id),
) -> CreateBillingSessionResponse:
await validate_billing_enabled()
base_url = _get_base_url(request)
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
validate_saas_environment(request)
customer_id = await stripe_service.find_or_create_customer(user_id)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_info['customer_id'],
customer=customer_id,
line_items=[
{
'price_data': {
@@ -178,22 +282,21 @@ async def create_checkout_session(
'tax_behavior': 'exclusive',
},
'quantity': 1,
},
}
],
mode='payment',
payment_method_types=['card'],
saved_payment_method_options={
'payment_method_save': 'enabled',
},
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
)
logger.info(
'created_stripe_checkout_session',
extra={
'stripe_customer_id': customer_info['customer_id'],
'stripe_customer_id': customer_id,
'user_id': user_id,
'org_id': customer_info['org_id'],
'amount': body.amount,
'checkout_session_id': checkout_session.id,
},
@@ -202,14 +305,105 @@ async def create_checkout_session(
billing_session = BillingSession(
id=checkout_session.id,
user_id=user_id,
org_id=customer_info['org_id'],
price=body.amount,
price_code='NA',
billing_session_type=BillingSessionType.DIRECT_PAYMENT.value,
)
session.add(billing_session)
session.commit()
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
@billing_router.post('/subscription-checkout-session')
async def create_subscription_checkout_session(
request: Request,
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
user_id: str = Depends(get_user_id),
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
# Prevent duplicate subscriptions for the same user
with session_maker() as session:
now = datetime.now(UTC)
existing_active_subscription = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.status == 'ACTIVE')
.filter(SubscriptionAccess.user_id == user_id)
.filter(SubscriptionAccess.start_at <= now)
.filter(SubscriptionAccess.end_at >= now)
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not cancelled
.first()
)
if existing_active_subscription:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Cannot create subscription: User already has an active subscription that has not been cancelled',
)
customer_id = await stripe_service.find_or_create_customer(user_id)
subscription_price_data = SUBSCRIPTION_PRICE_DATA[billing_session_type.value]
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_id,
line_items=[
{
'price_data': subscription_price_data,
'quantity': 1,
}
],
mode='subscription',
payment_method_types=['card'],
saved_payment_method_options={
'payment_method_save': 'enabled',
},
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
subscription_data={
'metadata': {
'user_id': user_id,
'billing_session_type': billing_session_type.value,
}
},
)
logger.info(
'created_stripe_subscription_checkout_session',
extra={
'stripe_customer_id': customer_id,
'user_id': user_id,
'checkout_session_id': checkout_session.id,
'billing_session_type': billing_session_type.value,
},
)
with session_maker() as session:
billing_session = BillingSession(
id=checkout_session.id,
user_id=user_id,
price=subscription_price_data['unit_amount'],
price_code='NA',
billing_session_type=billing_session_type.value,
)
session.add(billing_session)
session.commit()
return CreateBillingSessionResponse(
redirect_url=typing.cast(str, checkout_session.url)
)
@billing_router.get('/create-subscription-checkout-session')
async def create_subscription_checkout_session_via_get(
request: Request,
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
user_id: str = Depends(get_user_id),
) -> RedirectResponse:
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
validate_saas_environment(request)
response = await create_subscription_checkout_session(
request, billing_session_type, user_id
)
return RedirectResponse(response.redirect_url)
# Callback endpoint for successful Stripe payments - updates user credits and billing session status
@@ -231,6 +425,15 @@ async def success_callback(session_id: str, request: Request):
)
raise HTTPException(status.HTTP_400_BAD_REQUEST)
# Any non direct payment (Subscription) is processed in the invoice_payment.paid by the webhook
if (
billing_session.billing_session_type
!= BillingSessionType.DIRECT_PAYMENT.value
):
return RedirectResponse(
f'{request.base_url}settings?checkout=success', status_code=302
)
stripe_session = stripe.checkout.Session.retrieve(session_id)
if stripe_session.status != 'complete':
# Hopefully this never happens - we get a redirect from stripe where the payment is not yet complete
@@ -244,40 +447,34 @@ async def success_callback(session_id: str, request: Request):
)
raise HTTPException(status.HTTP_400_BAD_REQUEST)
user = await UserStore.get_user_by_id_async(billing_session.user_id)
user_team_info = await LiteLlmManager.get_user_team_info(
billing_session.user_id, str(user.current_org_id)
)
amount_subtotal = stripe_session.amount_subtotal or 0
add_credits = amount_subtotal / 100
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
)
new_max_budget = max_budget + add_credits
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
# Update max budget in litellm
user_json = await _get_litellm_user(client, billing_session.user_id)
amount_subtotal = stripe_session.amount_subtotal or 0
add_credits = amount_subtotal / 100
new_max_budget = (
(user_json.get('user_info') or {}).get('max_budget') or 0
) + add_credits
await _upsert_litellm_user(client, billing_session.user_id, new_max_budget)
await LiteLlmManager.update_team_and_users_budget(
str(user.current_org_id), new_max_budget
)
# Store transaction status
billing_session.status = 'completed'
billing_session.price = add_credits
billing_session.updated_at = datetime.now(UTC)
session.merge(billing_session)
logger.info(
'stripe_checkout_success',
extra={
'amount_subtotal': stripe_session.amount_subtotal,
'user_id': billing_session.user_id,
'org_id': str(user.current_org_id),
'checkout_session_id': billing_session.id,
'stripe_customer_id': stripe_session.customer,
},
)
session.commit()
# Store transaction status
billing_session.status = 'completed'
billing_session.price = amount_subtotal
billing_session.updated_at = datetime.now(UTC)
session.merge(billing_session)
logger.info(
'stripe_checkout_success',
extra={
'amount_subtotal': stripe_session.amount_subtotal,
'user_id': billing_session.user_id,
'checkout_session_id': billing_session.id,
'stripe_customer_id': stripe_session.customer,
},
)
session.commit()
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
f'{request.base_url}settings/billing?checkout=success', status_code=302
)
@@ -304,14 +501,206 @@ async def cancel_callback(session_id: str, request: Request):
session.merge(billing_session)
session.commit()
# Redirect credit purchases to billing screen, subscriptions to LLM settings
if (
billing_session.billing_session_type
== BillingSessionType.DIRECT_PAYMENT.value
):
return RedirectResponse(
f'{request.base_url}settings/billing?checkout=cancel',
status_code=302,
)
else:
return RedirectResponse(
f'{request.base_url}settings?checkout=cancel', status_code=302
)
# If no billing session found, default to LLM settings (subscription flow)
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
f'{request.base_url}settings?checkout=cancel', status_code=302
)
def _get_base_url(request: Request) -> URL:
# Never send any part of the credit card process over a non secure connection
base_url = request.base_url
if base_url.hostname != 'localhost':
base_url = base_url.replace(scheme='https')
return base_url
@billing_router.post('/stripe-webhook')
async def stripe_webhook(request: Request) -> JSONResponse:
"""Endpoint for stripe webhooks."""
payload = await request.body()
sig_header = request.headers.get('stripe-signature')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, STRIPE_WEBHOOK_SECRET
)
except ValueError as e:
# Invalid payload
raise HTTPException(status_code=400, detail=f'Invalid payload: {e}')
except stripe.SignatureVerificationError as e:
# Invalid signature
raise HTTPException(status_code=400, detail=f'Invalid signature: {e}')
# Handle the event
logger.info('stripe_webhook_event', extra={'event': event})
event_type = event['type']
if event_type == 'invoice.paid':
invoice = event['data']['object']
amount_paid = invoice.amount_paid
metadata = invoice.parent.subscription_details.metadata # type: ignore
billing_session_type = metadata.billing_session_type
assert (
amount_paid == SUBSCRIPTION_PRICE_DATA[billing_session_type]['unit_amount']
)
user_id = metadata.user_id
start_at = datetime.now(UTC)
if billing_session_type == BillingSessionType.MONTHLY_SUBSCRIPTION.value:
end_at = start_at + relativedelta(months=1)
else:
raise ValueError(f'unknown_billing_session_type:{billing_session_type}')
with session_maker() as session:
subscription_access = SubscriptionAccess(
status='ACTIVE',
user_id=user_id,
start_at=start_at,
end_at=end_at,
amount_paid=amount_paid,
stripe_invoice_payment_id=invoice.payment_intent,
stripe_subscription_id=invoice.subscription, # Store Stripe subscription ID
)
session.add(subscription_access)
session.commit()
elif event_type == 'customer.subscription.updated':
subscription = event['data']['object']
subscription_id = subscription['id']
# Handle subscription cancellation
if subscription.get('cancel_at_period_end') is True:
with session_maker() as session:
subscription_access = (
session.query(SubscriptionAccess)
.filter(
SubscriptionAccess.stripe_subscription_id == subscription_id
)
.filter(SubscriptionAccess.status == 'ACTIVE')
.first()
)
if subscription_access and not subscription_access.cancelled_at:
subscription_access.cancelled_at = datetime.now(UTC)
session.merge(subscription_access)
session.commit()
logger.info(
'subscription_cancelled_via_webhook',
extra={
'stripe_subscription_id': subscription_id,
'user_id': subscription_access.user_id,
'subscription_access_id': subscription_access.id,
},
)
elif event_type == 'customer.subscription.deleted':
subscription = event['data']['object']
subscription_id = subscription['id']
with session_maker() as session:
subscription_access = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.stripe_subscription_id == subscription_id)
.filter(SubscriptionAccess.status == 'ACTIVE')
.first()
)
if subscription_access:
subscription_access.status = 'DISABLED'
subscription_access.updated_at = datetime.now(UTC)
session.merge(subscription_access)
session.commit()
# Reset user settings to free tier defaults
reset_user_to_free_tier_settings(subscription_access.user_id)
logger.info(
'subscription_expired_reset_to_free_tier',
extra={
'stripe_subscription_id': subscription_id,
'user_id': subscription_access.user_id,
'subscription_access_id': subscription_access.id,
},
)
else:
logger.info('stripe_webhook_unhandled_event_type', extra={'type': event_type})
return JSONResponse({'status': 'success'})
def reset_user_to_free_tier_settings(user_id: str) -> None:
"""Reset user settings to free tier defaults when subscription ends."""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
with session_maker() as session:
user_settings = settings_store.get_user_settings_by_keycloak_id(
user_id, session
)
if user_settings:
user_settings.llm_model = get_default_litellm_model()
user_settings.llm_api_key = None
user_settings.llm_api_key_for_byor = None
user_settings.llm_base_url = LITE_LLM_API_URL
user_settings.max_budget_per_task = None
user_settings.confirmation_mode = False
user_settings.enable_solvability_analysis = False
user_settings.security_analyzer = 'llm'
user_settings.agent = 'CodeActAgent'
user_settings.language = 'en'
user_settings.enable_default_condenser = True
user_settings.enable_sound_notifications = False
user_settings.enable_proactive_conversation_starters = True
user_settings.user_consents_to_analytics = False
session.merge(user_settings)
session.commit()
logger.info(
'user_settings_reset_to_free_tier',
extra={
'user_id': user_id,
'reset_timestamp': datetime.now(UTC).isoformat(),
},
)
async def _get_litellm_user(client: httpx.AsyncClient, user_id: str) -> dict:
"""Get a user from litellm with the id matching that given.
If no such user exists, returns a dummy user in the format:
`{'user_id': '<USER_ID>', 'user_info': {'spend': 0}, 'keys': [], 'teams': []}`
"""
response = await client.get(
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
)
response.raise_for_status()
return response.json()
async def _upsert_litellm_user(
client: httpx.AsyncClient, user_id: str, max_budget: float
):
"""Insert / Update a user in litellm."""
response = await client.post(
f'{LITE_LLM_API_URL}/user/update',
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
json={
'user_id': user_id,
'max_budget': max_budget,
},
)
response.raise_for_status()

View File

@@ -5,8 +5,8 @@ from threading import Thread
from fastapi import APIRouter, FastAPI
from sqlalchemy import func, select
from storage.database import a_session_maker, get_engine, session_maker
from storage.user import User
from storage.database import a_session_maker, engine, session_maker
from storage.user_settings import UserSettings
from openhands.core.logger import openhands_logger as logger
from openhands.utils.async_utils import wait_all
@@ -47,7 +47,6 @@ def add_debugging_routes(api: FastAPI):
- checked_out: Number of connections currently in use
- overflow: Number of overflow connections created beyond pool_size
"""
engine = get_engine()
return {
'checked_in': engine.pool.checkedin(),
'checked_out': engine.pool.checkedout(),
@@ -128,9 +127,8 @@ def _db_check(delay: int):
delay: Number of seconds to hold the database connection
"""
with session_maker() as session:
num_users = session.query(User).count()
num_users = session.query(UserSettings).count()
time.sleep(delay)
engine = get_engine()
logger.info(
'check',
extra={
@@ -157,7 +155,7 @@ async def _a_db_check(delay: int):
delay: Number of seconds to hold the database connection
"""
async with a_session_maker() as a_session:
stmt = select(func.count(User.id))
stmt = select(func.count(UserSettings.id))
num_users = await a_session.execute(stmt)
await asyncio.sleep(delay)
logger.info(f'a_num_users:{num_users.scalar_one()}')

View File

@@ -21,7 +21,7 @@ from server.utils.conversation_callback_utils import (
update_conversation_stats,
)
from storage.database import session_maker
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.stored_conversation_metadata import StoredConversationMetadata
from openhands.server.shared import conversation_manager
@@ -226,12 +226,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
def _get_user_id(conversation_id: str) -> str:
with session_maker() as session:
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
conversation_metadata = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.conversation_id == conversation_id)
.first()
)
return str(conversation_metadata_saas.user_id)
return conversation_metadata.user_id
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.future import select
from storage.database import session_maker
from storage.feedback import ConversationFeedback
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.stored_conversation_metadata import StoredConversationMetadata
from openhands.events.event_store import EventStore
from openhands.server.shared import file_store
@@ -33,10 +33,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
def _verify_conversation():
with session_maker() as session:
metadata = (
session.query(StoredConversationMetadataSaas)
session.query(StoredConversationMetadata)
.filter(
StoredConversationMetadataSaas.conversation_id == conversation_id,
StoredConversationMetadataSaas.user_id == user_id,
StoredConversationMetadata.conversation_id == conversation_id,
StoredConversationMetadata.user_id == user_id,
)
.first()
)

View File

@@ -1,7 +1,6 @@
import hashlib
import json
import os
import zlib
from base64 import b64decode, b64encode
from urllib.parse import parse_qs, urlencode, urlparse
@@ -52,11 +51,7 @@ def add_github_proxy_routes(app: FastAPI):
state_payload = json.dumps(
[query_params['state'][0], query_params['redirect_uri'][0]]
)
# Compress before encrypting to reduce URL length
# This is critical for feature deployments where reCAPTCHA tokens in state
# can cause "URL too long" errors from GitHub
compressed_payload = zlib.compress(state_payload.encode())
state = b64encode(_fernet().encrypt(compressed_payload)).decode()
state = b64encode(_fernet().encrypt(state_payload.encode())).decode()
query_params['state'] = [state]
query_params['redirect_uri'] = [
f'https://{request.url.netloc}/github-proxy/callback'
@@ -72,9 +67,7 @@ def add_github_proxy_routes(app: FastAPI):
parsed_url = urlparse(str(request.url))
query_params = parse_qs(parsed_url.query)
state = query_params['state'][0]
# Decrypt and decompress (reverse of github_proxy_start)
decrypted_payload = _fernet().decrypt(b64decode(state.encode()))
decrypted_state = zlib.decompress(decrypted_payload).decode()
decrypted_state = _fernet().decrypt(b64decode(state.encode())).decode()
# Build query Params
state, redirect_uri = json.loads(decrypted_state)

View File

@@ -1,5 +1,3 @@
import hashlib
import hmac
import json
import os
import re
@@ -7,16 +5,15 @@ import uuid
from urllib.parse import urlparse
import requests
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request, status
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
from fastapi.responses import JSONResponse, RedirectResponse
from integrations.jira.jira_manager import JiraManager
from integrations.models import Message, SourceType
from integrations.utils import HOST_URL
from pydantic import BaseModel, Field, field_validator
from server.auth.constants import JIRA_CLIENT_ID, JIRA_CLIENT_SECRET
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
from storage.jira_workspace import JiraWorkspace
from server.constants import WEB_HOST
from storage.redis import create_redis_client
from openhands.core.logger import openhands_logger as logger
@@ -27,7 +24,7 @@ JIRA_WEBHOOKS_ENABLED = os.environ.get('JIRA_WEBHOOKS_ENABLED', '0') in (
'1',
'true',
)
JIRA_REDIRECT_URI = f'{HOST_URL}/integration/jira/callback'
JIRA_REDIRECT_URI = f'https://{WEB_HOST}/integration/jira/callback'
JIRA_SCOPES = 'read:me read:jira-user read:jira-work'
JIRA_AUTH_URL = 'https://auth.atlassian.com/authorize'
JIRA_TOKEN_URL = 'https://auth.atlassian.com/oauth/token'
@@ -125,63 +122,6 @@ jira_manager = JiraManager(token_manager)
redis_client = create_redis_client()
async def verify_jira_signature(body: bytes, signature: str, payload: dict):
"""
Verify Jira webhook signature.
Args:
body: Raw request body bytes
signature: Signature from x-hub-signature header (format: "sha256=<hash>")
payload: Parsed JSON payload from webhook
Raises:
HTTPException: 403 if signature verification fails or workspace is invalid
Returns:
None (raises exception on failure)
"""
if not signature:
raise HTTPException(
status_code=403, detail='x-hub-signature header is missing!'
)
workspace_name = jira_manager.get_workspace_name_from_payload(payload)
if workspace_name is None:
logger.warning('[Jira] No workspace name found in webhook payload')
raise HTTPException(
status_code=403, detail='Workspace name not found in payload'
)
workspace: (
JiraWorkspace | None
) = await jira_manager.integration_store.get_workspace_by_name(workspace_name)
if workspace is None:
logger.warning(f'[Jira] Could not identify workspace {workspace_name}')
raise HTTPException(status_code=403, detail='Unidentified workspace')
if workspace.status != 'active':
logger.warning(
'[Jira] Workspace is inactive',
extra={
'jira_workspace_id': workspace.id,
'parsed_workspace_name': workspace.name,
'status': workspace.status,
},
)
raise HTTPException(status_code=403, detail='Workspace is inactive')
webhook_secret = token_manager.decrypt_text(workspace.webhook_secret)
expected_signature = hmac.new(
webhook_secret.encode(), body, hashlib.sha256
).hexdigest()
if not hmac.compare_digest(expected_signature, signature):
raise HTTPException(status_code=403, detail="Request signatures didn't match!")
async def _handle_workspace_link_creation(
user_id: str, jira_user_id: str, target_workspace: str
):
@@ -276,7 +216,6 @@ async def _validate_workspace_update_permissions(user_id: str, target_workspace:
async def jira_events(
request: Request,
background_tasks: BackgroundTasks,
x_hub_signature: str = Header(None),
):
"""Handle Jira webhook events."""
# Check if Jira webhooks are enabled
@@ -288,15 +227,13 @@ async def jira_events(
)
try:
parts = x_hub_signature.split('=', 1)
if not (len(parts) == 2 and parts[1]):
raise HTTPException(status_code=403, detail='Malformed x-hub-signature!')
signature_valid, signature, payload = await jira_manager.validate_request(
request
)
signature = parts[1]
body = await request.body()
payload = await request.json()
await verify_jira_signature(body, signature, payload)
if not signature_valid:
logger.warning('[Jira] Invalid webhook signature')
raise HTTPException(status_code=403, detail='Invalid webhook signature!')
# Check for duplicate requests using Redis
key = f'jira:{signature}'

View File

@@ -15,6 +15,7 @@ from integrations.slack.slack_manager import SlackManager
from integrations.utils import (
HOST_URL,
)
from pydantic import SecretStr
from server.auth.constants import (
KEYCLOAK_CLIENT_ID,
KEYCLOAK_REALM_NAME,
@@ -34,7 +35,6 @@ from slack_sdk.web.async_client import AsyncWebClient
from storage.database import session_maker
from storage.slack_team_store import SlackTeamStore
from storage.slack_user import SlackUser
from storage.user_store import UserStore
from openhands.integrations.service_types import ProviderType
from openhands.server.shared import config, sio
@@ -79,14 +79,6 @@ async def install_callback(
status_code=400,
)
if not config.jwt_secret:
logger.error('slack_install_callback_error JWT not configured.')
return _html_response(
title='Error',
description=html.escape('JWT not configured'),
status_code=500,
)
try:
client = AsyncWebClient() # no prepared token needed for this
# Complete the installation by calling oauth.v2.access API method
@@ -102,17 +94,16 @@ async def install_callback(
# Create a state variable for keycloak oauth
payload = {}
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
if state:
payload = jwt.decode(
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
state, jwt_secret.get_secret_value(), algorithms=['HS256']
)
payload['slack_user_id'] = authed_user.get('id')
payload['bot_access_token'] = bot_access_token
payload['team_id'] = team_id
state = jwt.encode(
payload, config.jwt_secret.get_secret_value(), algorithm='HS256'
)
state = jwt.encode(payload, jwt_secret.get_secret_value(), algorithm='HS256')
# Redirect into keycloak
scope = quote('openid email profile offline_access')
@@ -158,16 +149,9 @@ async def keycloak_callback(
status_code=400,
)
if not config.jwt_secret:
logger.error('problem_retrieving_keycloak_tokens JWT not configured.')
return _html_response(
title='Error',
description=html.escape('JWT not configured'),
status_code=500,
)
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
payload: dict[str, str] = jwt.decode(
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
state, jwt_secret.get_secret_value(), algorithms=['HS256']
)
slack_user_id = payload['slack_user_id']
bot_access_token = payload['bot_access_token']
@@ -196,13 +180,6 @@ async def keycloak_callback(
user_info = await token_manager.get_user_info(keycloak_access_token)
keycloak_user_id = user_info['sub']
user = await UserStore.get_user_by_id_async(keycloak_user_id)
if not user:
return _html_response(
title='Failed to authenticate.',
description=f'Please re-login into <a href="{HOST_URL}" style="color:#ecedee;text-decoration:underline;">OpenHands Cloud</a>. Then try <a href="https://docs.all-hands.dev/usage/cloud/slack-installation" style="color:#ecedee;text-decoration:underline;">installing the OpenHands Slack App</a> again',
status_code=400,
)
# These tokens are offline access tokens - store them!
await token_manager.store_offline_token(keycloak_user_id, keycloak_refresh_token)
@@ -234,7 +211,6 @@ async def keycloak_callback(
slack_display_name = slack_user_info.data['user']['profile']['display_name']
slack_user = SlackUser(
keycloak_user_id=keycloak_user_id,
org_id=user.current_org_id,
slack_user_id=slack_user_id,
slack_display_name=slack_display_name,
)
@@ -329,7 +305,7 @@ async def on_form_interaction(request: Request, background_tasks: BackgroundTask
body = await request.body()
form = await request.form()
payload = json.loads(form.get('payload'))
payload = json.loads(form.get('payload')) # type: ignore[arg-type]
logger.info('slack_on_form_interaction', extra={'payload': payload})

View File

@@ -21,7 +21,7 @@ DEVICE_CODE_EXPIRES_IN = 600 # 10 minutes
DEVICE_TOKEN_POLL_INTERVAL = 5 # seconds
API_KEY_NAME = 'Device Link Access Key'
KEY_EXPIRATION_TIME = timedelta(days=7) # Key expires in a week
KEY_EXPIRATION_TIME = timedelta(days=1) # Key expires in 24 hours
# ---------------------------------------------------------------------------
# Models

View File

@@ -1,171 +0,0 @@
from pydantic import BaseModel, EmailStr, Field
from storage.org import Org
class OrgCreationError(Exception):
"""Base exception for organization creation errors."""
pass
class OrgNameExistsError(OrgCreationError):
"""Raised when an organization name already exists."""
def __init__(self, name: str):
self.name = name
super().__init__(f'Organization with name "{name}" already exists')
class LiteLLMIntegrationError(OrgCreationError):
"""Raised when LiteLLM integration fails."""
pass
class OrgDatabaseError(OrgCreationError):
"""Raised when database operations fail."""
pass
class OrgDeletionError(Exception):
"""Base exception for organization deletion errors."""
pass
class OrgAuthorizationError(OrgDeletionError):
"""Raised when user is not authorized to delete organization."""
def __init__(self, message: str = 'Not authorized to delete organization'):
super().__init__(message)
class OrgNotFoundError(Exception):
"""Raised when organization is not found or user doesn't have access."""
def __init__(self, org_id: str):
self.org_id = org_id
super().__init__(f'Organization with id "{org_id}" not found')
class OrgCreate(BaseModel):
"""Request model for creating a new organization."""
# Required fields
name: str = Field(min_length=1, max_length=255, strip_whitespace=True)
contact_name: str
contact_email: EmailStr = Field(strip_whitespace=True)
class OrgResponse(BaseModel):
"""Response model for organization."""
id: str
name: str
contact_name: str
contact_email: str
conversation_expiration: int | None = None
agent: str | None = None
default_max_iterations: int | None = None
security_analyzer: str | None = None
confirmation_mode: bool | None = None
default_llm_model: str | None = None
default_llm_api_key_for_byor: str | None = None
default_llm_base_url: str | None = None
remote_runtime_resource_factor: int | None = None
enable_default_condenser: bool = True
billing_margin: float | None = None
enable_proactive_conversation_starters: bool = True
sandbox_base_container_image: str | None = None
sandbox_runtime_container_image: str | None = None
org_version: int = 0
mcp_config: dict | None = None
search_api_key: str | None = None
sandbox_api_key: str | None = None
max_budget_per_task: float | None = None
enable_solvability_analysis: bool | None = None
v1_enabled: bool | None = None
credits: float | None = None
@classmethod
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
"""Create an OrgResponse from an Org entity.
Args:
org: The organization entity to convert
credits: Optional credits value (defaults to None)
Returns:
OrgResponse: The response model instance
"""
return cls(
id=str(org.id),
name=org.name,
contact_name=org.contact_name,
contact_email=org.contact_email,
conversation_expiration=org.conversation_expiration,
agent=org.agent,
default_max_iterations=org.default_max_iterations,
security_analyzer=org.security_analyzer,
confirmation_mode=org.confirmation_mode,
default_llm_model=org.default_llm_model,
default_llm_api_key_for_byor=None,
default_llm_base_url=org.default_llm_base_url,
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
enable_default_condenser=org.enable_default_condenser
if org.enable_default_condenser is not None
else True,
billing_margin=org.billing_margin,
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
if org.enable_proactive_conversation_starters is not None
else True,
sandbox_base_container_image=org.sandbox_base_container_image,
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
org_version=org.org_version if org.org_version is not None else 0,
mcp_config=org.mcp_config,
search_api_key=None,
sandbox_api_key=None,
max_budget_per_task=org.max_budget_per_task,
enable_solvability_analysis=org.enable_solvability_analysis,
v1_enabled=org.v1_enabled,
credits=credits,
)
class OrgPage(BaseModel):
"""Paginated response model for organization list."""
items: list[OrgResponse]
next_page_id: str | None = None
class OrgUpdate(BaseModel):
"""Request model for updating an organization."""
# Basic organization information (any authenticated user can update)
contact_name: str | None = None
contact_email: EmailStr | None = Field(default=None, strip_whitespace=True)
conversation_expiration: int | None = None
default_max_iterations: int | None = Field(default=None, gt=0)
remote_runtime_resource_factor: int | None = Field(default=None, gt=0)
billing_margin: float | None = Field(default=None, ge=0, le=1)
enable_proactive_conversation_starters: bool | None = None
sandbox_base_container_image: str | None = None
sandbox_runtime_container_image: str | None = None
mcp_config: dict | None = None
sandbox_api_key: str | None = None
max_budget_per_task: float | None = Field(default=None, gt=0)
enable_solvability_analysis: bool | None = None
v1_enabled: bool | None = None
# LLM settings (require admin/owner role)
default_llm_model: str | None = None
default_llm_api_key_for_byor: str | None = None
default_llm_base_url: str | None = None
search_api_key: str | None = None
security_analyzer: str | None = None
agent: str | None = None
confirmation_mode: bool | None = None
enable_default_condenser: bool | None = None
condenser_max_size: int | None = Field(default=None, ge=20)

View File

@@ -1,402 +0,0 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
LiteLLMIntegrationError,
OrgAuthorizationError,
OrgCreate,
OrgDatabaseError,
OrgNameExistsError,
OrgNotFoundError,
OrgPage,
OrgResponse,
OrgUpdate,
)
from storage.org_service import OrgService
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
# Initialize API router
org_router = APIRouter(prefix='/api/organizations')
@org_router.get('', response_model=OrgPage)
async def list_user_orgs(
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
] = 100,
user_id: str = Depends(get_user_id),
) -> OrgPage:
"""List organizations for the authenticated user.
This endpoint returns a paginated list of all organizations that the
authenticated user is a member of.
Args:
page_id: Optional page ID (offset) for pagination
limit: Maximum number of organizations to return (1-100, default 100)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgPage: Paginated list of organizations
Raises:
HTTPException: 500 if retrieval fails
"""
logger.info(
'Listing organizations for user',
extra={
'user_id': user_id,
'page_id': page_id,
'limit': limit,
},
)
try:
# Fetch organizations from service layer
orgs, next_page_id = OrgService.get_user_orgs_paginated(
user_id=user_id,
page_id=page_id,
limit=limit,
)
# Convert Org entities to OrgResponse objects
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
logger.info(
'Successfully retrieved organizations',
extra={
'user_id': user_id,
'org_count': len(org_responses),
'has_more': next_page_id is not None,
},
)
return OrgPage(items=org_responses, next_page_id=next_page_id)
except Exception as e:
logger.exception(
'Unexpected error listing organizations',
extra={'user_id': user_id, 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve organizations',
)
@org_router.post('', response_model=OrgResponse, status_code=status.HTTP_201_CREATED)
async def create_org(
org_data: OrgCreate,
user_id: str = Depends(get_admin_user_id),
) -> OrgResponse:
"""Create a new organization.
This endpoint allows authenticated users with @openhands.dev email to create
a new organization. The user who creates the organization automatically becomes
its owner.
Args:
org_data: Organization creation data
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The created organization details
Raises:
HTTPException: 403 if user email domain is not @openhands.dev
HTTPException: 409 if organization name already exists
HTTPException: 500 if creation fails
"""
logger.info(
'Creating new organization',
extra={
'user_id': user_id,
'org_name': org_data.name,
},
)
try:
# Use service layer to create organization
org = await OrgService.create_org_with_owner(
name=org_data.name,
contact_name=org_data.contact_name,
contact_email=org_data.contact_email,
user_id=user_id,
)
# Retrieve credits from LiteLLM
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse.from_org(org, credits=credits)
except OrgNameExistsError as e:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e),
)
except LiteLLMIntegrationError as e:
logger.error(
'LiteLLM integration failed',
extra={'user_id': user_id, 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to create LiteLLM integration',
)
except OrgDatabaseError as e:
logger.error(
'Database operation failed',
extra={'user_id': user_id, 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to create organization',
)
except Exception as e:
logger.exception(
'Unexpected error creating organization',
extra={'user_id': user_id, 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
async def get_org(
org_id: UUID,
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Get organization details by ID.
This endpoint allows authenticated users who are members of an organization
to retrieve its details. Only members of the organization can access this endpoint.
Args:
org_id: Organization ID (UUID)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The organization details
Raises:
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
HTTPException: 404 if organization not found or user is not a member
HTTPException: 500 if retrieval fails
"""
logger.info(
'Retrieving organization details',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to get organization with membership validation
org = await OrgService.get_org_by_id(
org_id=org_id,
user_id=user_id,
)
# Retrieve credits from LiteLLM
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse.from_org(org, credits=credits)
except OrgNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except Exception as e:
logger.exception(
'Unexpected error retrieving organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
async def delete_org(
org_id: UUID,
user_id: str = Depends(get_admin_user_id),
) -> dict:
"""Delete an organization.
This endpoint allows authenticated organization owners to delete their organization.
All associated data including organization members, conversations, billing data,
and external LiteLLM team resources will be permanently removed.
Args:
org_id: Organization ID to delete
user_id: Authenticated user ID (injected by dependency)
Returns:
dict: Confirmation message with deleted organization details
Raises:
HTTPException: 403 if user is not the organization owner
HTTPException: 404 if organization not found
HTTPException: 500 if deletion fails
"""
logger.info(
'Organization deletion requested',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to delete organization with cleanup
deleted_org = await OrgService.delete_org_with_cleanup(
user_id=user_id,
org_id=org_id,
)
logger.info(
'Organization deletion completed successfully',
extra={
'user_id': user_id,
'org_id': str(org_id),
'org_name': deleted_org.name,
},
)
return {
'message': 'Organization deleted successfully',
'organization': {
'id': str(deleted_org.id),
'name': deleted_org.name,
'contact_name': deleted_org.contact_name,
'contact_email': deleted_org.contact_email,
},
}
except OrgNotFoundError as e:
logger.warning(
'Organization not found for deletion',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except OrgAuthorizationError as e:
logger.warning(
'User not authorized to delete organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database error during organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to delete organization',
)
except Exception as e:
logger.exception(
'Unexpected error during organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.patch('/{org_id}', response_model=OrgResponse)
async def update_org(
org_id: UUID,
update_data: OrgUpdate,
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Update an existing organization.
This endpoint allows authenticated users to update organization settings.
LLM-related settings require admin or owner role in the organization.
Args:
org_id: Organization ID to update (UUID validated by FastAPI)
update_data: Organization update data
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The updated organization details
Raises:
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
HTTPException: 403 if user lacks permission for LLM settings
HTTPException: 404 if organization not found
HTTPException: 422 if validation errors occur (handled by FastAPI)
HTTPException: 500 if update fails
"""
logger.info(
'Updating organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to update organization with permission checks
updated_org = await OrgService.update_org_with_permissions(
org_id=org_id,
update_data=update_data,
user_id=user_id,
)
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
credits = await OrgService.get_org_credits(user_id, updated_org.id)
return OrgResponse.from_org(updated_org, credits=credits)
except ValueError as e:
# Organization not found
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except PermissionError as e:
# User lacks permission for LLM settings
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database operation failed',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to update organization',
)
except Exception as e:
logger.exception(
'Unexpected error updating organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)

View File

@@ -23,7 +23,6 @@ from sqlalchemy import orm
from storage.api_key_store import ApiKeyStore
from storage.database import session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from openhands.controller.agent import Agent
from openhands.core.config import LLMConfig, OpenHandsConfig
@@ -647,18 +646,16 @@ class SaasNestedConversationManager(ConversationManager):
"""
with session_maker() as session:
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadataSaas.conversation_id == conversation_id
)
conversation_metadata = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.conversation_id == conversation_id)
.first()
)
if not conversation_metadata_saas:
if not conversation_metadata:
raise ValueError(f'No conversation found {conversation_id}')
return str(conversation_metadata_saas.user_id)
return conversation_metadata.user_id
async def _get_runtime_status_from_nested_runtime(
self, session_api_key: Any | None, nested_url: str, conversation_id: str
@@ -997,17 +994,9 @@ class SaasNestedConversationManager(ConversationManager):
with session_maker() as session:
# Only include conversations updated in the past week
one_week_ago = datetime.now(UTC) - timedelta(days=7)
query = (
session.query(StoredConversationMetadata.conversation_id)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.filter(
StoredConversationMetadataSaas.user_id == user_id,
StoredConversationMetadata.last_updated_at >= one_week_ago,
)
query = session.query(StoredConversationMetadata.conversation_id).filter(
StoredConversationMetadata.user_id == user_id,
StoredConversationMetadata.last_updated_at >= one_week_ago,
)
user_conversation_ids = set(query)
return user_conversation_ids
@@ -1081,16 +1070,11 @@ class SaasNestedConversationManager(ConversationManager):
.filter(StoredConversationMetadata.conversation_id == conversation_id)
.first()
)
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
.first()
)
if conversation_metadata is None or conversation_metadata_saas is None:
if conversation_metadata is None:
# Conversation is running in different server
return
user_id = conversation_metadata_saas.user_id
user_id = conversation_metadata.user_id
# Get the id of the next event which is not present
events_dir = get_conversation_events_dir(

View File

@@ -26,7 +26,6 @@ from server.sharing.shared_conversation_models import (
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
@@ -58,7 +57,7 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
query = self._public_select_with_saas_metadata()
query = self._public_select()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
@@ -105,17 +104,14 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.all()
rows = result.scalars().all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [
self._to_shared_conversation(stored, saas_metadata=saas_metadata)
for stored, saas_metadata in rows
]
items = [self._to_shared_conversation(row) for row in rows]
# Calculate next page ID
next_page_id = None
@@ -156,18 +152,17 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single public conversation info, returning None if missing or not shared."""
query = self._public_select_with_saas_metadata().where(
query = self._public_select().where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result = await self.db_session.execute(query)
row = result.first()
stored = result.scalar_one_or_none()
if row is None:
if stored is None:
return None
stored, saas_metadata = row
return self._to_shared_conversation(stored, saas_metadata=saas_metadata)
return self._to_shared_conversation(stored)
def _public_select(self):
"""Create a select query that only returns public conversations."""
@@ -178,25 +173,6 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
return query
def _public_select_with_saas_metadata(self):
"""Create a select query that returns public conversations with SAAS metadata.
This joins with conversation_metadata_saas to retrieve the user_id needed
for constructing the correct event storage path. Uses LEFT OUTER JOIN to
support conversations that may not have SAAS metadata (e.g., in tests).
"""
query = (
select(StoredConversationMetadata, StoredConversationMetadataSaas)
.outerjoin(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.where(StoredConversationMetadata.conversation_version == 'V1')
.where(StoredConversationMetadata.public == True) # noqa: E712
)
return query
def _apply_filters(
self,
query,
@@ -235,16 +211,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
def _to_shared_conversation(
self,
stored: StoredConversationMetadata,
saas_metadata: StoredConversationMetadataSaas | None = None,
sub_conversation_ids: list[UUID] | None = None,
) -> SharedConversation:
"""Convert StoredConversationMetadata to SharedConversation.
Args:
stored: The base conversation metadata from conversation_metadata table.
saas_metadata: Optional SAAS metadata containing user_id and org_id.
sub_conversation_ids: Optional list of sub-conversation IDs.
"""
"""Convert StoredConversationMetadata to SharedConversation."""
# V1 conversations should always have a sandbox_id
sandbox_id = stored.sandbox_id
assert sandbox_id is not None
@@ -270,16 +239,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
created_at = self._fix_timezone(stored.created_at)
updated_at = self._fix_timezone(stored.last_updated_at)
# Get user_id from SAAS metadata if available
created_by_user_id = (
str(saas_metadata.user_id)
if saas_metadata and saas_metadata.user_id
else None
)
return SharedConversation(
id=UUID(stored.conversation_id),
created_by_user_id=created_by_user_id,
created_by_user_id=stored.user_id if stored.user_id else None,
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,

View File

@@ -1,350 +0,0 @@
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
from datetime import datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from sqlalchemy import func, select
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.user import User
from openhands.app_server.app_conversation.app_conversation_info_service import (
AppConversationInfoService,
AppConversationInfoServiceInjector,
)
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
AppConversationInfoPage,
AppConversationSortOrder,
)
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.services.injector import InjectorState
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
async def _secure_select(self):
query = (
select(StoredConversationMetadata)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
async def _secure_select_with_saas_metadata(self):
"""Select query that includes SAAS metadata for retrieving user_id."""
query = (
select(StoredConversationMetadata, StoredConversationMetadataSaas)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
async def search_app_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> AppConversationInfoPage:
"""Search for conversations with user_id from SAAS metadata."""
query = await self._secure_select_with_saas_metadata()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
# Exclude sub-conversations (only include top-level conversations)
query = query.where(
StoredConversationMetadata.parent_conversation_id.is_(None)
)
query = self._apply_filters_with_saas_metadata(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
# Add sort order
if sort_order == AppConversationSortOrder.CREATED_AT:
query = query.order_by(StoredConversationMetadata.created_at)
elif sort_order == AppConversationSortOrder.CREATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.created_at.desc())
elif sort_order == AppConversationSortOrder.UPDATED_AT:
query = query.order_by(StoredConversationMetadata.last_updated_at)
elif sort_order == AppConversationSortOrder.UPDATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
elif sort_order == AppConversationSortOrder.TITLE:
query = query.order_by(StoredConversationMetadata.title)
elif sort_order == AppConversationSortOrder.TITLE_DESC:
query = query.order_by(StoredConversationMetadata.title.desc())
# Apply pagination
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Apply limit and get one extra to check if there are more results
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [
self._to_info_with_user_id(stored_metadata, saas_metadata)
for stored_metadata, saas_metadata in rows
]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
return AppConversationInfoPage(items=items, next_page_id=next_page_id)
async def count_app_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count conversations matching the given filters with SAAS metadata."""
query = (
select(func.count(StoredConversationMetadata.conversation_id))
.select_from(
StoredConversationMetadata.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
# Apply user filtering
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
query = self._apply_filters_with_saas_metadata(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
result = await self.db_session.execute(query)
count = result.scalar()
return count or 0
def _apply_filters_with_saas_metadata(
self,
query,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
):
"""Apply filters to query that includes SAAS metadata."""
# Apply the same filters as the base class
conditions = []
if title__contains is not None:
conditions.append(
StoredConversationMetadata.title.like(f'%{title__contains}%')
)
if created_at__gte is not None:
conditions.append(StoredConversationMetadata.created_at >= created_at__gte)
if created_at__lt is not None:
conditions.append(StoredConversationMetadata.created_at < created_at__lt)
if updated_at__gte is not None:
conditions.append(
StoredConversationMetadata.last_updated_at >= updated_at__gte
)
if updated_at__lt is not None:
conditions.append(
StoredConversationMetadata.last_updated_at < updated_at__lt
)
if conditions:
query = query.where(*conditions)
return query
async def get_app_conversation_info(
self, conversation_id: UUID
) -> AppConversationInfo | None:
"""Get conversation info with user_id from SAAS metadata."""
query = await self._secure_select_with_saas_metadata()
query = query.where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result_set = await self.db_session.execute(query)
result = result_set.first()
if result:
stored_metadata, saas_metadata = result
return self._to_info_with_user_id(stored_metadata, saas_metadata)
return None
async def batch_get_app_conversation_info(
self, conversation_ids: list[UUID]
) -> list[AppConversationInfo | None]:
"""Batch get conversation info with user_id from SAAS metadata."""
conversation_id_strs = [
str(conversation_id) for conversation_id in conversation_ids
]
query = await self._secure_select_with_saas_metadata()
query = query.where(
StoredConversationMetadata.conversation_id.in_(conversation_id_strs)
)
result = await self.db_session.execute(query)
rows = result.all()
# Create a mapping of conversation_id to (metadata, saas_metadata)
info_by_id = {}
for stored_metadata, saas_metadata in rows:
info_by_id[stored_metadata.conversation_id] = (
stored_metadata,
saas_metadata,
)
results: list[AppConversationInfo | None] = []
for conversation_id in conversation_id_strs:
if conversation_id in info_by_id:
stored_metadata, saas_metadata = info_by_id[conversation_id]
results.append(
self._to_info_with_user_id(stored_metadata, saas_metadata)
)
else:
results.append(None)
return results
async def save_app_conversation_info(
self, info: AppConversationInfo
) -> AppConversationInfo:
"""Save conversation info and create/update SAAS metadata with user_id and org_id."""
# Save the base conversation metadata
await super().save_app_conversation_info(info)
# Get current user_id for SAAS metadata
user_id_str = await self.user_context.get_user_id()
if user_id_str:
# Convert string user_id to UUID
user_id_uuid = UUID(user_id_str)
user_query = select(User).where(User.id == user_id_uuid)
result = await self.db_session.execute(user_query)
user = result.scalar_one_or_none()
assert user
# Check if SAAS metadata already exists
saas_query = select(StoredConversationMetadataSaas).where(
StoredConversationMetadataSaas.conversation_id == str(info.id)
)
result = await self.db_session.execute(saas_query)
existing_saas_metadata = result.scalar_one_or_none()
assert existing_saas_metadata is None or (
existing_saas_metadata.user_id == user_id_uuid
and existing_saas_metadata.org_id == user.current_org_id
)
if not existing_saas_metadata:
# Create new SAAS metadata
# Set org_id to user_id as specified in requirements
saas_metadata = StoredConversationMetadataSaas(
conversation_id=str(info.id),
user_id=user_id_uuid,
org_id=user.current_org_id,
)
self.db_session.add(saas_metadata)
await self.db_session.commit()
return info
def _to_info_with_user_id(
self,
stored: StoredConversationMetadata,
saas_metadata: StoredConversationMetadataSaas,
) -> AppConversationInfo:
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
# Use the base _to_info method to get the basic info
info = self._to_info(stored)
# Override the created_by_user_id with the user_id from SAAS metadata
info.created_by_user_id = (
str(saas_metadata.user_id) if saas_metadata.user_id else None
)
return info
class SaasAppConversationInfoServiceInjector(AppConversationInfoServiceInjector):
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[AppConversationInfoService, None]:
from openhands.app_server.config import (
get_db_session,
get_user_context,
)
async with (
get_user_context(state, request) as user_context,
get_db_session(state, request) as db_session,
):
service = SaasSQLAppConversationInfoService(
db_session=db_session, user_context=user_context
)
yield service

View File

@@ -1,85 +0,0 @@
from storage.api_key import ApiKey
from storage.auth_tokens import AuthTokens
from storage.billing_session import BillingSession
from storage.billing_session_type import BillingSessionType
from storage.conversation_callback import CallbackStatus, ConversationCallback
from storage.conversation_work import ConversationWork
from storage.experiment_assignment import ExperimentAssignment
from storage.feedback import ConversationFeedback, Feedback
from storage.github_app_installation import GithubAppInstallation
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
from storage.jira_conversation import JiraConversation
from storage.jira_dc_conversation import JiraDcConversation
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from storage.linear_conversation import LinearConversation
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from storage.openhands_pr import OpenhandsPR
from storage.org import Org
from storage.org_member import OrgMember
from storage.proactive_convos import ProactiveConversation
from storage.role import Role
from storage.slack_conversation import SlackConversation
from storage.slack_team import SlackTeam
from storage.slack_user import SlackUser
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.stored_custom_secrets import StoredCustomSecrets
from storage.stored_offline_token import StoredOfflineToken
from storage.stored_repository import StoredRepository
from storage.stripe_customer import StripeCustomer
from storage.subscription_access import SubscriptionAccess
from storage.subscription_access_status import SubscriptionAccessStatus
from storage.user import User
from storage.user_repo_map import UserRepositoryMap
from storage.user_settings import UserSettings
__all__ = [
'ApiKey',
'AuthTokens',
'BillingSession',
'BillingSessionType',
'CallbackStatus',
'ConversationCallback',
'ConversationFeedback',
'StoredConversationMetadataSaas',
'ConversationWork',
'ExperimentAssignment',
'Feedback',
'GithubAppInstallation',
'GitlabWebhook',
'JiraConversation',
'JiraDcConversation',
'JiraDcUser',
'JiraDcWorkspace',
'JiraUser',
'JiraWorkspace',
'LinearConversation',
'LinearUser',
'LinearWorkspace',
'MaintenanceTask',
'MaintenanceTaskStatus',
'OpenhandsPR',
'Org',
'OrgMember',
'ProactiveConversation',
'Role',
'SlackConversation',
'SlackTeam',
'SlackUser',
'StoredConversationMetadata',
'StoredOfflineToken',
'StoredRepository',
'StoredCustomSecrets',
'StripeCustomer',
'SubscriptionAccess',
'SubscriptionAccessStatus',
'User',
'UserRepositoryMap',
'UserSettings',
'WebhookStatus',
]

View File

@@ -1,6 +1,4 @@
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy import Column, DateTime, Integer, String, text
from storage.base import Base
@@ -13,13 +11,9 @@ class ApiKey(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
key = Column(String(255), nullable=False, unique=True, index=True)
user_id = Column(String(255), nullable=False, index=True)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
name = Column(String(255), nullable=True)
created_at = Column(
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
)
last_used_at = Column(DateTime, nullable=True)
expires_at = Column(DateTime, nullable=True)
# Relationships
org = relationship('Org', back_populates='api_keys')

View File

@@ -9,7 +9,6 @@ from sqlalchemy import update
from sqlalchemy.orm import sessionmaker
from storage.api_key import ApiKey
from storage.database import session_maker
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
@@ -40,15 +39,10 @@ class ApiKeyStore:
The generated API key
"""
api_key = self.generate_api_key()
user = UserStore.get_user_by_id(user_id)
org_id = user.current_org_id
with self.session_maker() as session:
key_record = ApiKey(
key=api_key,
user_id=user_id,
org_id=org_id,
name=name,
expires_at=expires_at,
key=api_key, user_id=user_id, name=name, expires_at=expires_at
)
session.add(key_record)
session.commit()
@@ -114,15 +108,8 @@ class ApiKeyStore:
def list_api_keys(self, user_id: str) -> list[dict]:
"""List all API keys for a user."""
user = UserStore.get_user_by_id(user_id)
org_id = user.current_org_id
with self.session_maker() as session:
keys = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
return [
{
@@ -137,14 +124,9 @@ class ApiKeyStore:
]
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = UserStore.get_user_by_id(user_id)
org_id = user.current_org_id
with self.session_maker() as session:
keys: list[ApiKey] = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
)
for key in keys:
if key.name == 'MCP_API_KEY':

View File

@@ -1,8 +1,6 @@
from datetime import UTC, datetime
from sqlalchemy import DECIMAL, Column, DateTime, Enum, ForeignKey, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy import DECIMAL, Column, DateTime, Enum, String
from storage.base import Base
@@ -13,9 +11,9 @@ class BillingSession(Base): # type: ignore
"""
__tablename__ = 'billing_sessions'
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
status = Column(
Enum(
'in_progress',
@@ -26,6 +24,15 @@ class BillingSession(Base): # type: ignore
),
default='in_progress',
)
billing_session_type = Column(
Enum(
'DIRECT_PAYMENT',
'MONTHLY_SUBSCRIPTION',
name='billing_session_type_enum',
),
nullable=False,
default='DIRECT_PAYMENT',
)
price = Column(DECIMAL(19, 4), nullable=False)
price_code = Column(String, nullable=False)
created_at = Column(
@@ -36,6 +43,3 @@ class BillingSession(Base): # type: ignore
DateTime(timezone=True),
default=lambda: datetime.now(UTC), # type: ignore[attr-defined]
)
# Relationships
org = relationship('Org', back_populates='billing_sessions')

View File

@@ -1,10 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from openhands.events.observation.agent import AgentStateChangedObservation
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
@@ -15,6 +10,7 @@ from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text, text
from sqlalchemy import Enum as SQLEnum
from storage.base import Base
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.utils.import_utils import get_impl
@@ -37,7 +33,7 @@ class ConversationCallbackProcessor(BaseModel, ABC):
async def __call__(
self,
callback: ConversationCallback,
observation: 'AgentStateChangedObservation',
observation: AgentStateChangedObservation,
) -> None:
"""
Process a conversation event.

View File

@@ -1,38 +1,122 @@
"""
Database connection module for enterprise storage.
import asyncio
import os
This is for backwards compatibility with V0.
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.util import await_only
This module provides database engines and session makers by delegating to the
centralized DbSessionInjector from app_server/config.py. This ensures a single
source of truth for database connection configuration.
"""
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
DB_USER = os.environ.get('DB_USER', 'postgres')
DB_PASS = os.environ.get('DB_PASS', 'postgres').strip()
DB_NAME = os.environ.get('DB_NAME', 'openhands')
import contextlib
GCP_DB_INSTANCE = os.environ.get('GCP_DB_INSTANCE') # for GCP environments
GCP_PROJECT = os.environ.get('GCP_PROJECT')
GCP_REGION = os.environ.get('GCP_REGION')
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
# Initialize Cloud SQL Connector once at module level for GCP environments.
_connector = None
def _get_db_session_injector():
from openhands.app_server.config import get_global_config
def _get_db_engine():
if GCP_DB_INSTANCE: # GCP environments
_config = get_global_config()
return _config.db_session
def get_db_connection():
global _connector
from google.cloud.sql.connector import Connector
if not _connector:
_connector = Connector()
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
return _connector.connect(
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
)
return create_engine(
'postgresql+pg8000://',
creator=get_db_connection,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)
else:
host_string = (
f'postgresql+pg8000://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
)
return create_engine(
host_string,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)
def session_maker():
db_session_injector = _get_db_session_injector()
session_maker = db_session_injector.get_session_maker()
return session_maker()
async def async_creator():
from google.cloud.sql.connector import Connector
loop = asyncio.get_running_loop()
async with Connector(loop=loop) as connector:
conn = await connector.connect_async(
f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}', # Cloud SQL instance connection name"
'asyncpg',
user=DB_USER,
password=DB_PASS,
db=DB_NAME,
)
return conn
@contextlib.asynccontextmanager
async def a_session_maker():
db_session_injector = _get_db_session_injector()
a_session_maker = await db_session_injector.get_async_session_maker()
async with a_session_maker() as session:
yield session
def _get_async_db_engine():
if GCP_DB_INSTANCE: # GCP environments
def adapted_creator():
dbapi = engine.dialect.dbapi
from sqlalchemy.dialects.postgresql.asyncpg import (
AsyncAdapt_asyncpg_connection,
)
return AsyncAdapt_asyncpg_connection(
dbapi,
await_only(async_creator()),
prepared_statement_cache_size=100,
)
# create async connection pool with wrapped creator
return create_async_engine(
'postgresql+asyncpg://',
creator=adapted_creator,
# Use NullPool to disable connection pooling and avoid event loop issues
poolclass=NullPool,
)
else:
host_string = (
f'postgresql+asyncpg://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
)
return create_async_engine(
host_string,
# Use NullPool to disable connection pooling and avoid event loop issues
poolclass=NullPool,
)
def get_engine():
db_session_injector = _get_db_session_injector()
engine = db_session_injector.get_db_engine()
return engine
engine = _get_db_engine()
session_maker = sessionmaker(bind=engine)
a_engine = _get_async_db_engine()
a_session_maker = sessionmaker(
bind=a_engine,
class_=AsyncSession,
expire_on_commit=False,
# Configure the session to use the same connection for all operations in a transaction
# This helps prevent the "Task got Future attached to a different loop" error
future=True,
)

View File

@@ -1,114 +0,0 @@
import binascii
import hashlib
from base64 import b64decode, b64encode
from cryptography.fernet import Fernet, InvalidToken
from pydantic import SecretStr
from server.config import get_config
_jwt_service = None
_fernet = None
def encrypt_model(encrypt_keys: list, model_instance) -> dict:
return encrypt_kwargs(encrypt_keys, model_to_kwargs(model_instance))
def decrypt_model(decrypt_keys: list, model_instance) -> dict:
return decrypt_kwargs(decrypt_keys, model_to_kwargs(model_instance))
def encrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
for key, value in kwargs.items():
if value is None:
continue
if isinstance(value, dict):
encrypt_kwargs(encrypt_keys, value)
continue
if key in encrypt_keys:
value = encrypt_value(value)
kwargs[key] = value
return kwargs
def decrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
for key, value in kwargs.items():
try:
if value is None:
continue
if key in encrypt_keys:
value = decrypt_value(value)
kwargs[key] = value
except binascii.Error:
pass # Key is in legacy format...
return kwargs
def encrypt_value(value: str | SecretStr) -> str:
return get_jwt_service().create_jwe_token(
{'v': value.get_secret_value() if isinstance(value, SecretStr) else value}
)
def decrypt_value(value: str | SecretStr) -> str:
token = get_jwt_service().decrypt_jwe_token(
value.get_secret_value() if isinstance(value, SecretStr) else value
)
return token['v']
def get_jwt_service():
from openhands.app_server.config import get_global_config
global _jwt_service
if _jwt_service is None:
jwt_service_injector = get_global_config().jwt
assert jwt_service_injector is not None
_jwt_service = jwt_service_injector.get_jwt_service()
return _jwt_service
def decrypt_legacy_model(decrypt_keys: list, model_instance) -> dict:
return decrypt_legacy_kwargs(decrypt_keys, model_to_kwargs(model_instance))
def decrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
for key, value in kwargs.items():
try:
if value is None:
continue
if key in encrypt_keys:
value = decrypt_legacy_value(value)
kwargs[key] = value
except binascii.Error:
pass # Key is in legacy format...
except InvalidToken:
pass # Key not encrypted...
return kwargs
def decrypt_legacy_value(value: str | SecretStr) -> str:
if isinstance(value, SecretStr):
return (
get_fernet().decrypt(b64decode(value.get_secret_value().encode())).decode()
)
else:
return get_fernet().decrypt(b64decode(value.encode())).decode()
def get_fernet():
global _fernet
if _fernet is None:
jwt_secret = get_config().jwt_secret.get_secret_value()
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
_fernet = Fernet(fernet_key)
return _fernet
def model_to_kwargs(model_instance):
return {
column.name: getattr(model_instance, column.name)
for column in model_instance.__table__.columns
}

View File

@@ -1,16 +1,7 @@
import sys
from enum import IntEnum
from sqlalchemy import (
ARRAY,
Boolean,
Column,
DateTime,
Integer,
String,
Text,
text,
)
from sqlalchemy import ARRAY, Boolean, Column, DateTime, Integer, String, Text, text
from storage.base import Base

View File

@@ -118,7 +118,9 @@ class JiraIntegrationStore:
.first()
)
async def get_workspace_by_name(self, workspace_name: str) -> JiraWorkspace | None:
async def get_workspace_by_name(
self, workspace_name: str
) -> Optional[JiraWorkspace]:
"""Retrieve workspace by name."""
with session_maker() as session:
return (

View File

@@ -1,887 +0,0 @@
"""
Store class for managing organizational settings.
"""
import functools
import os
from typing import Any, Awaitable, Callable
import httpx
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from server.constants import (
DEFAULT_INITIAL_BUDGET,
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
ORG_SETTINGS_VERSION,
get_default_litellm_model,
)
from server.logger import logger
from storage.user_settings import UserSettings
from openhands.server.settings import Settings
from openhands.utils.http_session import httpx_verify_option
# Timeout in seconds for BYOR key verification requests to LiteLLM
BYOR_KEY_VERIFICATION_TIMEOUT = 5.0
# A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug.
UNLIMITED_BUDGET_SETTING = 1000000000.0
class LiteLlmManager:
"""Manage LiteLLM interactions."""
@staticmethod
async def create_entries(
org_id: str,
keycloak_user_id: str,
oss_settings: Settings,
create_user: bool,
) -> Settings | None:
logger.info(
'SettingsStore:update_settings_with_litellm_default:start',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
key = LITE_LLM_API_KEY
if not local_deploy:
# Get user info to add to litellm
token_manager = TokenManager()
keycloak_user_info = (
await token_manager.get_user_info_from_user_id(keycloak_user_id) or {}
)
async with httpx.AsyncClient(
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
)
if create_user:
await LiteLlmManager._create_user(
client, keycloak_user_info.get('email'), keycloak_user_id
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
)
key = await LiteLlmManager._generate_key(
client,
keycloak_user_id,
org_id,
f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}',
None,
)
oss_settings.agent = 'CodeActAgent'
# Use the model corresponding to the current user settings version
oss_settings.llm_model = get_default_litellm_model()
oss_settings.llm_api_key = SecretStr(key)
oss_settings.llm_base_url = LITE_LLM_API_URL
return oss_settings
@staticmethod
async def migrate_entries(
org_id: str,
keycloak_user_id: str,
user_settings: UserSettings,
) -> UserSettings | None:
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:start',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
if not local_deploy:
# Get user info to add to litellm
async with httpx.AsyncClient(
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
user_json = await LiteLlmManager._get_user(client, keycloak_user_id)
if not user_json:
return None
user_info = user_json['user_info']
max_budget = user_info.get('max_budget', 0.0)
spend = user_info.get('spend', 0.0)
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
if (
user_settings
and user_settings.user_version < 4
and user_settings.billing_margin
and user_settings.billing_margin != 1.0
):
billing_margin = user_settings.billing_margin
logger.info(
'user_settings_v4_budget_upgrade',
extra={
'max_budget': max_budget,
'billing_margin': billing_margin,
'spend': spend,
},
)
max_budget *= billing_margin
spend *= billing_margin
if not max_budget:
# if max_budget is None, then we've already migrated the User
return None
credits = max(max_budget - spend, 0.0)
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:create_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, credits
)
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:update_user',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_user(
client, keycloak_user_id, max_budget=UNLIMITED_BUDGET_SETTING
)
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:add_user_to_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, credits
)
if user_settings.llm_api_key:
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:update_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
user_settings.llm_api_key,
team_id=org_id,
)
if user_settings.llm_api_key_for_byor:
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:update_byor_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
user_settings.llm_api_key_for_byor,
team_id=org_id,
)
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:end',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
return user_settings
@staticmethod
async def update_team_and_users_budget(
team_id: str,
max_budget: float,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
async with httpx.AsyncClient(
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
await LiteLlmManager._update_team(client, team_id, None, max_budget)
team_info = await LiteLlmManager._get_team(client, team_id)
if not team_info:
return None
# TODO: change to use bulk update endpoint
for membership in team_info.get('team_memberships', []):
user_id = membership.get('user_id')
if not user_id:
continue
await LiteLlmManager._update_user_in_team(
client, user_id, team_id, max_budget
)
@staticmethod
async def _create_team(
client: httpx.AsyncClient,
team_alias: str,
team_id: str,
max_budget: float,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/team/new',
json={
'team_id': team_id,
'team_alias': team_alias,
'models': [],
'max_budget': max_budget,
'spend': 0,
'metadata': {
'version': ORG_SETTINGS_VERSION,
'model': get_default_litellm_model(),
},
},
)
# Team failed to create in litellm - this is an unforseen error state...
if not response.is_success:
if (
response.status_code == 400
and 'already exists. Please use a different team id' in response.text
):
# team already exists, so update, then return
await LiteLlmManager._update_team(
client, team_id, team_alias, max_budget
)
return
logger.error(
'error_creating_litellm_team',
extra={
'status_code': response.status_code,
'text': response.text,
'team_id': team_id,
'max_budget': max_budget,
},
)
response.raise_for_status()
@staticmethod
async def _get_team(client: httpx.AsyncClient, team_id: str) -> dict | None:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
"""Get a team from litellm with the id matching that given."""
response = await client.get(
f'{LITE_LLM_API_URL}/team/info?team_id={team_id}',
)
response.raise_for_status()
return response.json()
@staticmethod
async def _update_team(
client: httpx.AsyncClient,
team_id: str,
team_alias: str | None,
max_budget: float | None,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
json_data: dict[str, Any] = {
'team_id': team_id,
'metadata': {
'version': ORG_SETTINGS_VERSION,
'model': get_default_litellm_model(),
},
}
if max_budget is not None:
json_data['max_budget'] = max_budget
if team_alias is not None:
json_data['team_alias'] = team_alias
response = await client.post(
f'{LITE_LLM_API_URL}/team/update',
json=json_data,
)
# Team failed to update in litellm - this is an unforseen error state...
if not response.is_success:
logger.error(
'error_updating_litellm_team',
extra={
'status_code': response.status_code,
'text': response.text,
'team_id': [team_id],
'max_budget': max_budget,
},
)
response.raise_for_status()
@staticmethod
async def _create_user(
client: httpx.AsyncClient,
email: str | None,
keycloak_user_id: str,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/user/new',
json={
'user_email': email,
'models': [],
'user_id': keycloak_user_id,
'teams': [LITE_LLM_TEAM_ID],
'auto_create_key': False,
'send_invite_email': False,
'metadata': {
'version': ORG_SETTINGS_VERSION,
'model': get_default_litellm_model(),
},
},
)
if not response.is_success:
logger.warning(
'duplicate_user_email',
extra={
'user_id': keycloak_user_id,
'email': email,
},
)
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
response = await client.post(
f'{LITE_LLM_API_URL}/user/new',
json={
'user_email': None,
'models': [],
'user_id': keycloak_user_id,
'teams': [LITE_LLM_TEAM_ID],
'auto_create_key': False,
'send_invite_email': False,
'metadata': {
'version': ORG_SETTINGS_VERSION,
'model': get_default_litellm_model(),
},
},
)
# User failed to create in litellm - this is an unforseen error state...
if not response.is_success:
if (
response.status_code in (400, 409)
and 'already exists' in response.text
):
logger.warning(
'litellm_user_already_exists',
extra={
'user_id': keycloak_user_id,
},
)
return
logger.error(
'error_creating_litellm_user',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [keycloak_user_id],
'email': None,
},
)
response.raise_for_status()
@staticmethod
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
"""Get a user from litellm with the id matching that given."""
response = await client.get(
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
)
response.raise_for_status()
return response.json()
@staticmethod
async def _update_user(
client: httpx.AsyncClient,
keycloak_user_id: str,
**kwargs,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
payload = {
'user_id': keycloak_user_id,
}
payload.update(kwargs)
response = await client.post(
f'{LITE_LLM_API_URL}/user/update',
json=payload,
)
if not response.is_success:
logger.error(
'error_updating_litellm_user',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': keycloak_user_id,
},
)
response.raise_for_status()
@staticmethod
async def _update_key(
client: httpx.AsyncClient,
keycloak_user_id: str,
key: str,
**kwargs,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
payload = {
'key': key,
}
payload.update(kwargs)
response = await client.post(
f'{LITE_LLM_API_URL}/key/update',
json=payload,
)
if not response.is_success:
if response.status_code == 401:
logger.warning(
'invalid_litellm_key_during_update',
extra={
'user_id': keycloak_user_id,
},
)
return
logger.error(
'error_updating_litellm_key',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': keycloak_user_id,
},
)
response.raise_for_status()
@staticmethod
async def _delete_user(
client: httpx.AsyncClient,
keycloak_user_id: str,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [keycloak_user_id]}
)
if not response.is_success:
logger.error(
'error_deleting_litellm_user',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [keycloak_user_id],
},
)
response.raise_for_status()
@staticmethod
async def _delete_team(
client: httpx.AsyncClient,
team_id: str,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/team/delete',
json={'team_ids': [team_id]},
)
if not response.is_success:
if response.status_code == 404:
# Team doesn't exist, that's fine
logger.info(
'Team already deleted or does not exist',
extra={'team_id': team_id},
)
return
logger.error(
'error_deleting_litellm_team',
extra={
'status_code': response.status_code,
'text': response.text,
'team_id': team_id,
},
)
response.raise_for_status()
logger.info(
'LiteLlmManager:_delete_team:team_deleted',
extra={'team_id': team_id},
)
@staticmethod
async def _add_user_to_team(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str,
max_budget: float,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/team/member_add',
json={
'team_id': team_id,
'member': {'user_id': keycloak_user_id, 'role': 'user'},
'max_budget_in_team': max_budget,
},
)
# Failed to add user to team - this is an unforseen error state...
if not response.is_success:
if (
response.status_code == 400
and 'already in team' in response.text.lower()
):
logger.warning(
'user_already_in_team',
extra={
'user_id': keycloak_user_id,
'team_id': team_id,
},
)
return
logger.error(
'error_adding_litellm_user_to_team',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [keycloak_user_id],
'team_id': [team_id],
'max_budget': max_budget,
},
)
response.raise_for_status()
@staticmethod
async def _get_user_team_info(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str,
) -> dict | None:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
team_info = await LiteLlmManager._get_team(client, team_id)
if not team_info:
return None
# Filter team_memberships based on team_id and keycloak_user_id
user_membership = next(
(
membership
for membership in team_info.get('team_memberships', [])
if membership.get('user_id') == keycloak_user_id
and membership.get('team_id') == team_id
),
None,
)
return user_membership
@staticmethod
async def _update_user_in_team(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str,
max_budget: float,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/team/member_update',
json={
'team_id': team_id,
'user_id': keycloak_user_id,
'max_budget_in_team': max_budget,
},
)
# Failed to update user in team - this is an unforseen error state...
if not response.is_success:
logger.error(
'error_updating_litellm_user_in_team',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [keycloak_user_id],
'team_id': [team_id],
'max_budget': max_budget,
},
)
response.raise_for_status()
@staticmethod
async def _generate_key(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str | None,
key_alias: str | None,
metadata: dict | None,
) -> str | None:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
json_data: dict[str, Any] = {
'user_id': keycloak_user_id,
'models': [],
}
if team_id is not None:
json_data['team_id'] = team_id
if key_alias is not None:
json_data['key_alias'] = key_alias
if metadata is not None:
json_data['metadata'] = metadata
response = await client.post(
f'{LITE_LLM_API_URL}/key/generate',
json=json_data,
)
# Failed to generate user key for team - this is an unforseen error state...
if not response.is_success:
logger.error(
'error_generate_user_team_key',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': keycloak_user_id,
'team_id': team_id,
'key_alias': key_alias,
},
)
response.raise_for_status()
response_json = response.json()
key = response_json['key']
logger.info(
'LiteLlmManager:_lite_llm_generate_user_team_key:key_created',
extra={
'user_id': keycloak_user_id,
'team_id': team_id,
'key_alias': key_alias,
},
)
return key
@staticmethod
async def verify_key(key: str, user_id: str) -> bool:
"""Verify that a key is valid in LiteLLM by making a lightweight API call.
Args:
key: The key to verify
user_id: The user ID for logging purposes
Returns:
True if the key is verified as valid, False if verification fails or key is invalid.
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
"""
if not (LITE_LLM_API_URL and key):
return False
try:
async with httpx.AsyncClient(
verify=httpx_verify_option(),
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
) as client:
# Make a lightweight request to verify the key
# Using /v1/models endpoint as it's lightweight and requires authentication
response = await client.get(
f'{LITE_LLM_API_URL}/v1/models',
headers={
'Authorization': f'Bearer {key}',
},
)
# Only 200 status code indicates valid key
if response.status_code == 200:
logger.debug(
'BYOR key verification successful',
extra={'user_id': user_id},
)
return True
# All other status codes (401, 403, 500, etc.) are treated as invalid
# This includes authentication errors and server errors
logger.warning(
'BYOR key verification failed - treating as invalid',
extra={
'user_id': user_id,
'status_code': response.status_code,
'key_prefix': key[:10] + '...' if len(key) > 10 else key,
},
)
return False
except (httpx.TimeoutException, Exception) as e:
# Any exception (timeout, network error, etc.) means we can't verify
# Return False to trigger regeneration rather than returning potentially invalid key
logger.warning(
'BYOR key verification error - treating as invalid to ensure key validity',
extra={
'user_id': user_id,
'error': str(e),
'error_type': type(e).__name__,
},
)
return False
@staticmethod
async def _get_key_info(
client: httpx.AsyncClient,
org_id: str,
keycloak_user_id: str,
) -> dict | None:
from storage.user_store import UserStore
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
user = await UserStore.get_user_by_id_async(keycloak_user_id)
if not user:
return {}
org_member = None
for om in user.org_members:
if om.org_id == org_id:
org_member = om
break
if not org_member or not org_member.llm_api_key:
return {}
response = await client.get(
f'{LITE_LLM_API_URL}/key/info?key={org_member.llm_api_key}'
)
response.raise_for_status()
response_json = response.json()
key_info = response_json.get('info')
if not key_info:
return {}
return {
'key_max_budget': key_info.get('max_budget'),
'key_spend': key_info.get('spend'),
}
@staticmethod
async def _delete_key_by_alias(
client: httpx.AsyncClient,
key_alias: str,
):
"""Delete a key from LiteLLM by its alias.
This is a best-effort operation that logs but does not raise on failure.
"""
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/key/delete',
json={
'key_aliases': [key_alias],
},
)
if response.is_success:
logger.info(
'LiteLlmManager:_delete_key_by_alias:key_deleted',
extra={'key_alias': key_alias},
)
elif response.status_code != 404:
# Log non-404 errors but don't fail
logger.warning(
'error_deleting_key_by_alias',
extra={
'key_alias': key_alias,
'status_code': response.status_code,
'text': response.text,
},
)
@staticmethod
async def _delete_key(
client: httpx.AsyncClient,
key_id: str,
key_alias: str | None = None,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/key/delete',
json={
'keys': [key_id],
},
)
# Failed to delete key...
if not response.is_success:
if response.status_code == 404:
# Key doesn't exist by key_id. If we have a key_alias,
# try deleting by alias to clean up any orphaned alias.
if key_alias:
await LiteLlmManager._delete_key_by_alias(client, key_alias)
return
logger.error(
'error_deleting_key',
extra={
'status_code': response.status_code,
'text': response.text,
},
)
response.raise_for_status()
logger.info(
'LiteLlmManager:_delete_key:key_deleted',
)
@staticmethod
def with_http_client(
internal_fn: Callable[..., Awaitable[Any]],
) -> Callable[..., Awaitable[Any]]:
@functools.wraps(internal_fn)
async def wrapper(*args, **kwargs):
async with httpx.AsyncClient(
headers={'x-goog-api-key': LITE_LLM_API_KEY}
) as client:
return await internal_fn(client, *args, **kwargs)
return wrapper
# Public methods with injected client
create_team = staticmethod(with_http_client(_create_team))
get_team = staticmethod(with_http_client(_get_team))
update_team = staticmethod(with_http_client(_update_team))
create_user = staticmethod(with_http_client(_create_user))
get_user = staticmethod(with_http_client(_get_user))
update_user = staticmethod(with_http_client(_update_user))
delete_user = staticmethod(with_http_client(_delete_user))
delete_team = staticmethod(with_http_client(_delete_team))
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
generate_key = staticmethod(with_http_client(_generate_key))
get_key_info = staticmethod(with_http_client(_get_key_info))
delete_key = staticmethod(with_http_client(_delete_key))

View File

@@ -1,100 +0,0 @@
"""
SQLAlchemy model for Organization.
"""
from uuid import uuid4
from pydantic import SecretStr
from server.constants import DEFAULT_BILLING_MARGIN
from sqlalchemy import JSON, UUID, Boolean, Column, Float, Integer, String
from sqlalchemy.orm import relationship
from storage.base import Base
from storage.encrypt_utils import decrypt_value, encrypt_value
class Org(Base): # type: ignore
"""Organization model."""
__tablename__ = 'org'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
name = Column(String, nullable=False, unique=True)
contact_name = Column(String, nullable=True)
contact_email = Column(String, nullable=True)
agent = Column(String, nullable=True)
default_max_iterations = Column(Integer, nullable=True)
security_analyzer = Column(String, nullable=True)
confirmation_mode = Column(Boolean, nullable=True, default=False)
default_llm_model = Column(String, nullable=True)
default_llm_base_url = Column(String, nullable=True)
remote_runtime_resource_factor = Column(Integer, nullable=True)
enable_default_condenser = Column(Boolean, nullable=False, default=True)
billing_margin = Column(Float, nullable=True, default=DEFAULT_BILLING_MARGIN)
enable_proactive_conversation_starters = Column(
Boolean, nullable=False, default=True
)
sandbox_base_container_image = Column(String, nullable=True)
sandbox_runtime_container_image = Column(String, nullable=True)
org_version = Column(Integer, nullable=False, default=0)
mcp_config = Column(JSON, nullable=True)
# encrypted column, don't set directly, set without the underscore
_search_api_key = Column(String, nullable=True)
# encrypted column, don't set directly, set without the underscore
_sandbox_api_key = Column(String, nullable=True)
max_budget_per_task = Column(Float, nullable=True)
enable_solvability_analysis = Column(Boolean, nullable=True, default=False)
v1_enabled = Column(Boolean, nullable=True)
conversation_expiration = Column(Integer, nullable=True)
condenser_max_size = Column(Integer, nullable=True)
# Relationships
org_members = relationship('OrgMember', back_populates='org')
current_users = relationship('User', back_populates='current_org')
billing_sessions = relationship('BillingSession', back_populates='org')
stored_conversation_metadata_saas = relationship(
'StoredConversationMetadataSaas', back_populates='org'
)
user_secrets = relationship('StoredCustomSecrets', back_populates='org')
api_keys = relationship('ApiKey', back_populates='org')
slack_conversations = relationship('SlackConversation', back_populates='org')
slack_users = relationship('SlackUser', back_populates='org')
stripe_customers = relationship('StripeCustomer', back_populates='org')
def __init__(self, **kwargs):
# Handle known SQLAlchemy columns directly
for key in list(kwargs):
if hasattr(self.__class__, key):
setattr(self, key, kwargs.pop(key))
# Handle custom property-style fields
if 'search_api_key' in kwargs:
self.search_api_key = kwargs.pop('search_api_key')
if 'sandbox_api_key' in kwargs:
self.sandbox_api_key = kwargs.pop('sandbox_api_key')
if kwargs:
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
@property
def search_api_key(self) -> SecretStr | None:
if self._search_api_key:
decrypted = decrypt_value(self._search_api_key)
return SecretStr(decrypted)
return None
@search_api_key.setter
def search_api_key(self, value: str | SecretStr | None):
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
self._search_api_key = encrypt_value(raw) if raw else None
@property
def sandbox_api_key(self) -> SecretStr | None:
if self._sandbox_api_key:
decrypted = decrypt_value(self._sandbox_api_key)
return SecretStr(decrypted)
return None
@sandbox_api_key.setter
def sandbox_api_key(self, value: str | SecretStr | None):
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
self._sandbox_api_key = encrypt_value(raw) if raw else None

View File

@@ -1,67 +0,0 @@
"""
SQLAlchemy model for Organization-Member relationship.
"""
from pydantic import SecretStr
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from storage.base import Base
from storage.encrypt_utils import decrypt_value, encrypt_value
class OrgMember(Base): # type: ignore
"""Junction table for organization-member relationships with roles."""
__tablename__ = 'org_member'
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), primary_key=True)
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), primary_key=True)
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
_llm_api_key = Column(String, nullable=False)
max_iterations = Column(Integer, nullable=True)
llm_model = Column(String, nullable=True)
_llm_api_key_for_byor = Column(String, nullable=True)
llm_base_url = Column(String, nullable=True)
status = Column(String, nullable=True)
# Relationships
org = relationship('Org', back_populates='org_members')
user = relationship('User', back_populates='org_members')
role = relationship('Role', back_populates='org_members')
def __init__(self, **kwargs):
# Handle known SQLAlchemy columns directly
for key in list(kwargs):
if hasattr(self.__class__, key):
setattr(self, key, kwargs.pop(key))
# Handle custom property-style fields
if 'llm_api_key' in kwargs:
self.llm_api_key = kwargs.pop('llm_api_key')
if 'llm_api_key_for_byor' in kwargs:
self.llm_api_key_for_byor = kwargs.pop('llm_api_key_for_byor')
if kwargs:
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
@property
def llm_api_key(self) -> SecretStr:
decrypted = decrypt_value(self._llm_api_key)
return SecretStr(decrypted)
@llm_api_key.setter
def llm_api_key(self, value: str | SecretStr):
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
self._llm_api_key = encrypt_value(raw)
@property
def llm_api_key_for_byor(self) -> SecretStr | None:
if self._llm_api_key_for_byor:
decrypted = decrypt_value(self._llm_api_key_for_byor)
return SecretStr(decrypted)
return None
@llm_api_key_for_byor.setter
def llm_api_key_for_byor(self, value: str | SecretStr | None):
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
self._llm_api_key_for_byor = encrypt_value(raw) if raw else None

View File

@@ -1,125 +0,0 @@
"""
Store class for managing organization-member relationships.
"""
from typing import Optional
from uuid import UUID
from storage.database import session_maker
from storage.org_member import OrgMember
from storage.user_settings import UserSettings
from openhands.storage.data_models.settings import Settings
class OrgMemberStore:
"""Store for managing organization-member relationships."""
@staticmethod
def add_user_to_org(
org_id: UUID,
user_id: UUID,
role_id: int,
llm_api_key: str,
status: Optional[str] = None,
) -> OrgMember:
"""Add a user to an organization with a specific role."""
with session_maker() as session:
org_member = OrgMember(
org_id=org_id,
user_id=user_id,
role_id=role_id,
llm_api_key=llm_api_key,
status=status,
)
session.add(org_member)
session.commit()
session.refresh(org_member)
return org_member
@staticmethod
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
"""Get organization-user relationship."""
with session_maker() as session:
return (
session.query(OrgMember)
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
.first()
)
@staticmethod
def get_user_orgs(user_id: int) -> list[OrgMember]:
"""Get all organizations for a user."""
with session_maker() as session:
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
@staticmethod
def get_org_members(org_id: UUID) -> list[OrgMember]:
"""Get all users in an organization."""
with session_maker() as session:
return session.query(OrgMember).filter(OrgMember.org_id == org_id).all()
@staticmethod
def update_org_member(org_member: OrgMember) -> None:
"""Update an organization-member relationship."""
with session_maker() as session:
session.merge(org_member)
session.commit()
@staticmethod
def update_user_role_in_org(
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
) -> Optional[OrgMember]:
"""Update user's role in an organization."""
with session_maker() as session:
org_member = (
session.query(OrgMember)
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
.first()
)
if not org_member:
return None
org_member.role_id = role_id
if status is not None:
org_member.status = status
session.commit()
session.refresh(org_member)
return org_member
@staticmethod
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
"""Remove a user from an organization."""
with session_maker() as session:
org_member = (
session.query(OrgMember)
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
.first()
)
if not org_member:
return False
session.delete(org_member)
session.commit()
return True
@staticmethod
def get_kwargs_from_settings(settings: Settings):
kwargs = {
normalized: getattr(settings, normalized)
for c in OrgMember.__table__.columns
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
}
return kwargs
@staticmethod
def get_kwargs_from_user_settings(user_settings: UserSettings):
kwargs = {
normalized: getattr(user_settings, normalized)
for c in OrgMember.__table__.columns
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
}
return kwargs

View File

@@ -1,844 +0,0 @@
"""
Service class for managing organization operations.
Separates business logic from route handlers.
"""
from uuid import UUID, uuid4
from uuid import UUID as parse_uuid
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
from server.routes.org_models import (
LiteLLMIntegrationError,
OrgAuthorizationError,
OrgDatabaseError,
OrgNameExistsError,
OrgNotFoundError,
OrgUpdate,
)
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.org_store import OrgStore
from storage.role_store import RoleStore
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
class OrgService:
"""Service for handling organization-related operations."""
@staticmethod
def validate_name_uniqueness(name: str) -> None:
"""
Validate that organization name is unique.
Args:
name: Organization name to validate
Raises:
OrgNameExistsError: If organization name already exists
"""
existing_org = OrgStore.get_org_by_name(name)
if existing_org is not None:
raise OrgNameExistsError(name)
@staticmethod
async def create_litellm_integration(org_id: UUID, user_id: str) -> dict:
"""
Create LiteLLM team integration for the organization.
Args:
org_id: Organization ID
user_id: User ID who will own the organization
Returns:
dict: LiteLLM settings object
Raises:
LiteLLMIntegrationError: If LiteLLM integration fails
"""
try:
settings = await UserStore.create_default_settings(
org_id=str(org_id), user_id=user_id, create_user=False
)
if not settings:
logger.error(
'Failed to create LiteLLM settings',
extra={'org_id': str(org_id), 'user_id': user_id},
)
raise LiteLLMIntegrationError('Failed to create LiteLLM settings')
logger.debug(
'LiteLLM integration created',
extra={'org_id': str(org_id), 'user_id': user_id},
)
return settings
except LiteLLMIntegrationError:
raise
except Exception as e:
logger.exception(
'Error creating LiteLLM integration',
extra={'org_id': str(org_id), 'user_id': user_id, 'error': str(e)},
)
raise LiteLLMIntegrationError(f'LiteLLM integration failed: {str(e)}')
@staticmethod
def create_org_entity(
org_id: UUID,
name: str,
contact_name: str,
contact_email: str,
) -> Org:
"""
Create an organization entity with basic information.
Args:
org_id: Organization UUID
name: Organization name
contact_name: Contact person name
contact_email: Contact email address
Returns:
Org: New organization entity (not yet persisted)
"""
return Org(
id=org_id,
name=name,
contact_name=contact_name,
contact_email=contact_email,
org_version=ORG_SETTINGS_VERSION,
default_llm_model=get_default_litellm_model(),
)
@staticmethod
def apply_litellm_settings_to_org(org: Org, settings: dict) -> None:
"""
Apply LiteLLM settings to organization entity.
Args:
org: Organization entity to update
settings: LiteLLM settings object
"""
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
for key, value in org_kwargs.items():
if hasattr(org, key):
setattr(org, key, value)
@staticmethod
def get_owner_role():
"""
Get the owner role from the database.
Returns:
Role: The owner role object
Raises:
Exception: If owner role not found
"""
owner_role = RoleStore.get_role_by_name('owner')
if not owner_role:
raise Exception('Owner role not found in database')
return owner_role
@staticmethod
def create_org_member_entity(
org_id: UUID,
user_id: str,
role_id: int,
settings: dict,
) -> OrgMember:
"""
Create an organization member entity.
Args:
org_id: Organization UUID
user_id: User ID (string that will be converted to UUID)
role_id: Role ID
settings: LiteLLM settings object
Returns:
OrgMember: New organization member entity (not yet persisted)
"""
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
return OrgMember(
org_id=org_id,
user_id=parse_uuid(user_id),
role_id=role_id,
status='active',
**org_member_kwargs,
)
@staticmethod
async def create_org_with_owner(
name: str,
contact_name: str,
contact_email: str,
user_id: str,
) -> Org:
"""
Create a new organization with the specified user as owner.
This method orchestrates the complete organization creation workflow:
1. Validates that the organization name doesn't already exist
2. Generates a unique organization ID
3. Creates LiteLLM team integration
4. Creates the organization entity
5. Applies LiteLLM settings
6. Creates owner membership
7. Persists everything in a transaction
If database persistence fails, LiteLLM resources are cleaned up (compensation).
Args:
name: Organization name (must be unique)
contact_name: Contact person name
contact_email: Contact email address
user_id: ID of the user who will be the owner
Returns:
Org: The created organization object
Raises:
OrgNameExistsError: If organization name already exists
LiteLLMIntegrationError: If LiteLLM integration fails
OrgDatabaseError: If database operations fail
"""
logger.info(
'Starting organization creation',
extra={'user_id': user_id, 'org_name': name},
)
# Step 1: Validate name uniqueness (fails early, no cleanup needed)
OrgService.validate_name_uniqueness(name)
# Step 2: Generate organization ID
org_id = uuid4()
# Step 3: Create LiteLLM integration (external state created)
settings = await OrgService.create_litellm_integration(org_id, user_id)
# Steps 4-7: Create entities and persist with compensation
# If any of these fail, we need to clean up LiteLLM resources
try:
# Step 4: Create organization entity
org = OrgService.create_org_entity(
org_id=org_id,
name=name,
contact_name=contact_name,
contact_email=contact_email,
)
# Step 5: Apply LiteLLM settings
OrgService.apply_litellm_settings_to_org(org, settings)
# Step 6: Get owner role and create member entity
owner_role = OrgService.get_owner_role()
org_member = OrgService.create_org_member_entity(
org_id=org_id,
user_id=user_id,
role_id=owner_role.id,
settings=settings,
)
# Step 7: Persist in transaction (critical section)
persisted_org = await OrgService._persist_with_compensation(
org, org_member, org_id, user_id
)
logger.info(
'Successfully created organization',
extra={
'org_id': str(persisted_org.id),
'org_name': persisted_org.name,
'user_id': user_id,
'role': 'owner',
},
)
return persisted_org
except OrgDatabaseError:
# Already handled by _persist_with_compensation, just re-raise
raise
except Exception as e:
# Unexpected error in steps 4-6, need to clean up LiteLLM
logger.error(
'Unexpected error during organization creation, initiating cleanup',
extra={
'org_id': str(org_id),
'user_id': user_id,
'error': str(e),
},
)
await OrgService._handle_failure_with_cleanup(
org_id, user_id, e, 'Failed to create organization'
)
@staticmethod
async def _persist_with_compensation(
org: Org,
org_member: OrgMember,
org_id: UUID,
user_id: str,
) -> Org:
"""
Persist organization with compensation on failure.
If database persistence fails, cleans up LiteLLM resources.
Args:
org: Organization entity to persist
org_member: Organization member entity to persist
org_id: Organization ID (for cleanup)
user_id: User ID (for cleanup)
Returns:
Org: The persisted organization object
Raises:
OrgDatabaseError: If database operations fail
"""
try:
persisted_org = OrgStore.persist_org_with_owner(org, org_member)
return persisted_org
except Exception as e:
logger.error(
'Database persistence failed, initiating LiteLLM cleanup',
extra={
'org_id': str(org_id),
'user_id': user_id,
'error': str(e),
},
)
await OrgService._handle_failure_with_cleanup(
org_id, user_id, e, 'Failed to create organization'
)
@staticmethod
async def _handle_failure_with_cleanup(
org_id: UUID,
user_id: str,
original_error: Exception,
error_message: str,
) -> None:
"""
Handle failure by cleaning up LiteLLM resources and raising appropriate error.
This method performs compensating transaction and raises OrgDatabaseError.
Args:
org_id: Organization ID
user_id: User ID
original_error: The original exception that caused the failure
error_message: Base error message for the exception
Raises:
OrgDatabaseError: Always raises with details about the failure
"""
cleanup_error = await OrgService._cleanup_litellm_resources(org_id, user_id)
if cleanup_error:
logger.error(
'Both operation and cleanup failed',
extra={
'org_id': str(org_id),
'user_id': user_id,
'original_error': str(original_error),
'cleanup_error': str(cleanup_error),
},
)
raise OrgDatabaseError(
f'{error_message}: {str(original_error)}. '
f'Cleanup also failed: {str(cleanup_error)}'
)
raise OrgDatabaseError(f'{error_message}: {str(original_error)}')
@staticmethod
async def _cleanup_litellm_resources(
org_id: UUID, user_id: str
) -> Exception | None:
"""
Compensating transaction: Clean up LiteLLM resources.
Deletes the team which should cascade to remove keys and memberships.
This is a best-effort operation - errors are logged but not raised.
Args:
org_id: Organization ID
user_id: User ID
Returns:
Exception | None: Exception if cleanup failed, None if successful
"""
try:
await LiteLlmManager.delete_team(str(org_id))
logger.info(
'Successfully cleaned up LiteLLM team',
extra={'org_id': str(org_id), 'user_id': user_id},
)
return None
except Exception as e:
logger.error(
'Failed to cleanup LiteLLM team (resources may be orphaned)',
extra={
'org_id': str(org_id),
'user_id': user_id,
'error': str(e),
},
)
return e
@staticmethod
def has_admin_or_owner_role(user_id: str, org_id: UUID) -> bool:
"""
Check if user has admin or owner role in the specified organization.
Args:
user_id: User ID to check
org_id: Organization ID to check membership in
Returns:
bool: True if user has admin or owner role, False otherwise
"""
try:
# Parse user_id as UUID for database query
user_uuid = parse_uuid(user_id)
# Get the user's membership in this organization
# Note: The type annotation says int but the actual column is UUID
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
if not org_member:
return False
# Get the role details
role = RoleStore.get_role_by_id(org_member.role_id)
if not role:
return False
# Admin and owner roles have elevated permissions
# Based on test files, both admin and owner have rank 1
return role.name in ['admin', 'owner']
except Exception as e:
logger.warning(
'Error checking user role in organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
'error': str(e),
},
)
return False
@staticmethod
def is_org_member(user_id: str, org_id: UUID) -> bool:
"""
Check if user is a member of the specified organization.
Args:
user_id: User ID to check
org_id: Organization ID to check membership in
Returns:
bool: True if user is a member, False otherwise
"""
try:
user_uuid = parse_uuid(user_id)
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
return org_member is not None
except Exception as e:
logger.warning(
'Error checking user membership in organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
'error': str(e),
},
)
return False
@staticmethod
def _get_llm_settings_fields() -> set[str]:
"""
Get the set of organization fields that are considered LLM settings
and require admin/owner role to update.
Returns:
set[str]: Set of field names that require elevated permissions
"""
return {
'default_llm_model',
'default_llm_api_key_for_byor',
'default_llm_base_url',
'search_api_key',
'security_analyzer',
'agent',
'confirmation_mode',
'enable_default_condenser',
'condenser_max_size',
}
@staticmethod
def _has_llm_settings_updates(update_data: OrgUpdate) -> set[str]:
"""
Check if the update contains any LLM settings fields.
Args:
update_data: The organization update data
Returns:
set[str]: Set of LLM fields being updated (empty if none)
"""
llm_fields = OrgService._get_llm_settings_fields()
update_dict = update_data.model_dump(exclude_none=True)
return llm_fields.intersection(update_dict.keys())
@staticmethod
async def update_org_with_permissions(
org_id: UUID,
update_data: OrgUpdate,
user_id: str,
) -> Org:
"""
Update organization with permission checks for LLM settings.
Args:
org_id: Organization UUID to update
update_data: Organization update data from request
user_id: ID of the user requesting the update
Returns:
Org: The updated organization object
Raises:
ValueError: If organization not found
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
OrgDatabaseError: If database update fails
"""
logger.info(
'Updating organization with permission checks',
extra={
'org_id': str(org_id),
'user_id': user_id,
'has_update_data': update_data is not None,
},
)
# Validate organization exists
existing_org = OrgStore.get_org_by_id(org_id)
if not existing_org:
raise ValueError(f'Organization with ID {org_id} not found')
# Check if user is a member of this organization
if not OrgService.is_org_member(user_id, org_id):
logger.warning(
'Non-member attempted to update organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
raise PermissionError(
'User must be a member of the organization to update it'
)
# Check if update contains any LLM settings
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
if llm_fields_being_updated:
# Verify user has admin or owner role
has_permission = OrgService.has_admin_or_owner_role(user_id, org_id)
if not has_permission:
logger.warning(
'User attempted to update LLM settings without permission',
extra={
'user_id': user_id,
'org_id': str(org_id),
'attempted_fields': list(llm_fields_being_updated),
},
)
raise PermissionError(
'Admin or owner role required to update LLM settings'
)
logger.debug(
'User has permission to update LLM settings',
extra={
'user_id': user_id,
'org_id': str(org_id),
'llm_fields': list(llm_fields_being_updated),
},
)
# Convert to dict for OrgStore (excluding None values)
update_dict = update_data.model_dump(exclude_none=True)
if not update_dict:
logger.info(
'No fields to update',
extra={'org_id': str(org_id), 'user_id': user_id},
)
return existing_org
# Perform the update
try:
updated_org = OrgStore.update_org(org_id, update_dict)
if not updated_org:
raise OrgDatabaseError('Failed to update organization in database')
logger.info(
'Organization updated successfully',
extra={
'org_id': str(org_id),
'user_id': user_id,
'updated_fields': list(update_dict.keys()),
},
)
return updated_org
except Exception as e:
logger.error(
'Failed to update organization',
extra={
'org_id': str(org_id),
'user_id': user_id,
'error': str(e),
},
)
raise OrgDatabaseError(f'Failed to update organization: {str(e)}')
@staticmethod
async def get_org_credits(user_id: str, org_id: UUID) -> float | None:
"""
Get organization credits from LiteLLM team.
Args:
user_id: User ID
org_id: Organization ID
Returns:
float | None: Credits (max_budget - spend) or None if LiteLLM not configured
"""
try:
user_team_info = await LiteLlmManager.get_user_team_info(
user_id, str(org_id)
)
if not user_team_info:
logger.warning(
'No team info available from LiteLLM',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
return None
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
)
spend = user_team_info.get('spend', 0)
credits = max(max_budget - spend, 0)
logger.debug(
'Retrieved organization credits',
extra={
'user_id': user_id,
'org_id': str(org_id),
'credits': credits,
'max_budget': max_budget,
'spend': spend,
},
)
return credits
except Exception as e:
logger.warning(
'Failed to retrieve organization credits',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
return None
@staticmethod
def get_user_orgs_paginated(
user_id: str, page_id: str | None = None, limit: int = 100
):
"""
Get paginated list of organizations for a user.
Args:
user_id: User ID (string that will be converted to UUID)
page_id: Optional page ID (offset as string) for pagination
limit: Maximum number of organizations to return
Returns:
Tuple of (list of Org objects, next_page_id or None)
"""
logger.debug(
'Fetching paginated organizations for user',
extra={'user_id': user_id, 'page_id': page_id, 'limit': limit},
)
# Convert user_id string to UUID
user_uuid = parse_uuid(user_id)
# Fetch organizations from store
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_uuid, page_id=page_id, limit=limit
)
logger.debug(
'Retrieved organizations for user',
extra={
'user_id': user_id,
'org_count': len(orgs),
'has_more': next_page_id is not None,
},
)
return orgs, next_page_id
@staticmethod
async def get_org_by_id(org_id: UUID, user_id: str) -> Org:
"""
Get organization by ID with membership validation.
This method verifies that the user is a member of the organization
before returning the organization details.
Args:
org_id: Organization ID
user_id: User ID (string that will be converted to UUID)
Returns:
Org: The organization object
Raises:
OrgNotFoundError: If organization not found or user is not a member
"""
logger.info(
'Retrieving organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
# Verify user is a member of the organization
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if not org_member:
logger.warning(
'User is not a member of organization or organization does not exist',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise OrgNotFoundError(str(org_id))
# Retrieve organization
org = OrgStore.get_org_by_id(org_id)
if not org:
logger.error(
'Organization not found despite valid membership',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise OrgNotFoundError(str(org_id))
logger.info(
'Successfully retrieved organization',
extra={
'org_id': str(org.id),
'org_name': org.name,
'user_id': user_id,
},
)
return org
@staticmethod
def verify_owner_authorization(user_id: str, org_id: UUID) -> None:
"""
Verify that the user is the owner of the organization.
Args:
user_id: User ID to check
org_id: Organization ID
Raises:
OrgNotFoundError: If organization doesn't exist
OrgAuthorizationError: If user is not authorized to delete
"""
# Check if organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise OrgNotFoundError(str(org_id))
# Check if user is a member of the organization
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if not org_member:
raise OrgAuthorizationError('User is not a member of this organization')
# Check if user has owner role
role = RoleStore.get_role_by_id(org_member.role_id)
if not role or role.name != 'owner':
raise OrgAuthorizationError(
'Only organization owners can delete organizations'
)
logger.debug(
'User authorization verified for organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id), 'role': role.name},
)
@staticmethod
async def delete_org_with_cleanup(user_id: str, org_id: UUID) -> Org:
"""
Delete organization with complete cleanup of all associated data.
This method performs the complete organization deletion workflow:
1. Verifies user authorization (owner only)
2. Performs database cascade deletion and LiteLLM cleanup in single transaction
Args:
user_id: User ID requesting deletion (must be owner)
org_id: Organization ID to delete
Returns:
Org: The deleted organization details
Raises:
OrgNotFoundError: If organization doesn't exist
OrgAuthorizationError: If user is not authorized to delete
OrgDatabaseError: If database operations or LiteLLM cleanup fail
"""
logger.info(
'Starting organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
# Step 1: Verify user authorization
OrgService.verify_owner_authorization(user_id, org_id)
# Step 2: Perform database cascade deletion with LiteLLM cleanup in transaction
try:
deleted_org = await OrgStore.delete_org_cascade(org_id)
if not deleted_org:
# This shouldn't happen since we verified existence above
raise OrgDatabaseError('Organization not found during deletion')
logger.info(
'Organization deletion completed successfully',
extra={
'user_id': user_id,
'org_id': str(org_id),
'org_name': deleted_org.name,
},
)
return deleted_org
except Exception as e:
logger.error(
'Organization deletion failed',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')

View File

@@ -1,363 +0,0 @@
"""
Store class for managing organizations.
"""
from typing import Optional
from uuid import UUID
from server.constants import (
LITE_LLM_API_URL,
ORG_SETTINGS_VERSION,
get_default_litellm_model,
)
from sqlalchemy import text
from sqlalchemy.orm import joinedload
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.org_member import OrgMember
from storage.user import User
from storage.user_settings import UserSettings
from openhands.core.logger import openhands_logger as logger
from openhands.storage.data_models.settings import Settings
class OrgStore:
"""Store for managing organizations."""
@staticmethod
def create_org(
kwargs: dict,
) -> Org:
"""Create a new organization."""
with session_maker() as session:
org = Org(**kwargs)
org.org_version = ORG_SETTINGS_VERSION
org.default_llm_model = get_default_litellm_model()
session.add(org)
session.commit()
session.refresh(org)
return org
@staticmethod
def get_org_by_id(org_id: UUID) -> Org | None:
"""Get organization by ID."""
org = None
with session_maker() as session:
org = session.query(Org).filter(Org.id == org_id).first()
return OrgStore._validate_org_version(org)
@staticmethod
def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None:
with session_maker() as session:
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == UUID(keycloak_user_id))
.first()
)
if not user:
logger.warning(f'User not found for ID {keycloak_user_id}')
return None
org_id = user.current_org_id
org = session.query(Org).filter(Org.id == org_id).first()
if not org:
logger.warning(
f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}'
)
return None
return OrgStore._validate_org_version(org)
@staticmethod
def get_org_by_name(name: str) -> Org | None:
"""Get organization by name."""
org = None
with session_maker() as session:
org = session.query(Org).filter(Org.name == name).first()
return OrgStore._validate_org_version(org)
@staticmethod
def _validate_org_version(org: Org) -> Org | None:
"""Check if we need to update org version."""
if org and org.org_version < ORG_SETTINGS_VERSION:
org = OrgStore.update_org(
org.id,
{
'org_version': ORG_SETTINGS_VERSION,
'default_llm_model': get_default_litellm_model(),
'llm_base_url': LITE_LLM_API_URL,
},
)
return org
@staticmethod
def list_orgs() -> list[Org]:
"""List all organizations."""
with session_maker() as session:
orgs = session.query(Org).all()
return orgs
@staticmethod
def get_user_orgs_paginated(
user_id: UUID, page_id: str | None = None, limit: int = 100
) -> tuple[list[Org], str | None]:
"""
Get paginated list of organizations for a user.
Args:
user_id: User UUID
page_id: Optional page ID (offset as string) for pagination
limit: Maximum number of organizations to return
Returns:
Tuple of (list of Org objects, next_page_id or None)
"""
with session_maker() as session:
# Build query joining OrgMember with Org
query = (
session.query(Org)
.join(OrgMember, Org.id == OrgMember.org_id)
.filter(OrgMember.user_id == user_id)
.order_by(Org.name)
)
# Apply pagination offset
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Fetch limit + 1 to check if there are more results
query = query.limit(limit + 1)
orgs = query.all()
# Check if there are more results
has_more = len(orgs) > limit
if has_more:
orgs = orgs[:limit]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
# Validate org versions
validated_orgs = [
OrgStore._validate_org_version(org) for org in orgs if org
]
validated_orgs = [org for org in validated_orgs if org is not None]
return validated_orgs, next_page_id
@staticmethod
def update_org(
org_id: UUID,
kwargs: dict,
) -> Optional[Org]:
"""Update organization details."""
with session_maker() as session:
org = session.query(Org).filter(Org.id == org_id).first()
if not org:
return None
if 'id' in kwargs:
kwargs.pop('id')
for key, value in kwargs.items():
if hasattr(org, key):
setattr(org, key, value)
session.commit()
session.refresh(org)
return org
@staticmethod
def get_kwargs_from_settings(settings: Settings):
kwargs = {}
for c in Org.__table__.columns:
# Normalize for lookup
normalized = (
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
)
if not hasattr(settings, normalized):
continue
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
key = c.name
if key.startswith('_'):
key = key[1:] # remove only the very first leading underscore
kwargs[key] = getattr(settings, normalized)
return kwargs
@staticmethod
def get_kwargs_from_user_settings(user_settings: UserSettings):
kwargs = {}
for c in Org.__table__.columns:
# Normalize for lookup
normalized = (
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
)
if not hasattr(user_settings, normalized):
continue
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
key = c.name
if key.startswith('_'):
key = key[1:] # remove only the very first leading underscore
kwargs[key] = getattr(user_settings, normalized)
kwargs['org_version'] = user_settings.user_version
return kwargs
@staticmethod
def persist_org_with_owner(
org: Org,
org_member: OrgMember,
) -> Org:
"""
Persist organization and owner membership in a single transaction.
Args:
org: Organization entity to persist
org_member: Organization member entity to persist
Returns:
Org: The persisted organization object
Raises:
Exception: If database operations fail
"""
with session_maker() as session:
session.add(org)
session.add(org_member)
session.commit()
session.refresh(org)
return org
@staticmethod
async def delete_org_cascade(org_id: UUID) -> Org | None:
"""
Delete organization and all associated data in cascade, including external LiteLLM cleanup.
Args:
org_id: UUID of the organization to delete
Returns:
Org: The deleted organization object, or None if not found
Raises:
Exception: If database operations or LiteLLM cleanup fail
"""
with session_maker() as session:
# First get the organization to return it
org = session.query(Org).filter(Org.id == org_id).first()
if not org:
return None
try:
# 1. Delete conversation data for organization conversations
session.execute(
text("""
DELETE FROM conversation_metadata
WHERE conversation_id IN (
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
)
"""),
{'org_id': str(org_id)},
)
session.execute(
text("""
DELETE FROM app_conversation_start_task
WHERE app_conversation_id::text IN (
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
)
"""),
{'org_id': str(org_id)},
)
# 2. Delete organization-owned data tables (direct org_id foreign keys)
session.execute(
text('DELETE FROM billing_sessions WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text(
'DELETE FROM conversation_metadata_saas WHERE org_id = :org_id'
),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM custom_secrets WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM api_keys WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM slack_conversation WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM slack_users WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM stripe_customers WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
# 3. Delete organization memberships
session.execute(
text('DELETE FROM org_member WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
# 4. Handle users with this as current_org_id
session.execute(
text(
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
),
{'org_id': str(org_id)},
)
# 5. Finally delete the organization
session.delete(org)
# 6. Clean up LiteLLM team before committing transaction
logger.info(
'Deleting LiteLLM team within database transaction',
extra={'org_id': str(org_id)},
)
await LiteLlmManager.delete_team(str(org_id))
# 7. Commit all changes only if everything succeeded
session.commit()
logger.info(
'Successfully deleted organization and all associated data including LiteLLM team',
extra={'org_id': str(org_id), 'org_name': org.name},
)
return org
except Exception as e:
session.rollback()
logger.error(
'Failed to delete organization - transaction rolled back',
extra={'org_id': str(org_id), 'error': str(e)},
)
raise

View File

@@ -1,21 +0,0 @@
"""
SQLAlchemy model for Role.
"""
from sqlalchemy import Column, Identity, Integer, String
from sqlalchemy.orm import relationship
from storage.base import Base
class Role(Base): # type: ignore
"""Role model for user permissions."""
__tablename__ = 'role'
id = Column(Integer, Identity(), primary_key=True)
name = Column(String, nullable=False, unique=True)
rank = Column(Integer, nullable=False)
# Relationships
users = relationship('User', back_populates='role')
org_members = relationship('OrgMember', back_populates='role')

View File

@@ -1,56 +0,0 @@
"""
Store class for managing roles.
"""
from typing import List, Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from storage.database import a_session_maker, session_maker
from storage.role import Role
class RoleStore:
"""Store for managing roles."""
@staticmethod
def create_role(name: str, rank: int) -> Role:
"""Create a new role."""
with session_maker() as session:
role = Role(name=name, rank=rank)
session.add(role)
session.commit()
session.refresh(role)
return role
@staticmethod
def get_role_by_id(role_id: int) -> Optional[Role]:
"""Get role by ID."""
with session_maker() as session:
return session.query(Role).filter(Role.id == role_id).first()
@staticmethod
def get_role_by_name(name: str) -> Optional[Role]:
"""Get role by name."""
with session_maker() as session:
return session.query(Role).filter(Role.name == name).first()
@staticmethod
async def get_role_by_name_async(
name: str,
session: Optional[AsyncSession] = None,
) -> Optional[Role]:
"""Get role by name."""
if session is not None:
result = await session.execute(select(Role).where(Role.name == name))
return result.scalars().first()
async with a_session_maker() as session:
result = await session.execute(select(Role).where(Role.name == name))
return result.scalars().first()
@staticmethod
def list_roles() -> List[Role]:
"""List all roles."""
with session_maker() as session:
return session.query(Role).order_by(Role.rank).all()

View File

@@ -4,13 +4,10 @@ import dataclasses
import logging
from dataclasses import dataclass
from datetime import UTC
from uuid import UUID
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.integrations.provider import ProviderType
@@ -32,37 +29,20 @@ logger = logging.getLogger(__name__)
class SaasConversationStore(ConversationStore):
user_id: str
session_maker: sessionmaker
org_id: UUID | None = None # will be fetched automatically
def __init__(self, user_id: str, session_maker: sessionmaker):
self.user_id = user_id
self.session_maker = session_maker
user = UserStore.get_user_by_id(user_id)
self.org_id = user.current_org_id if user else None
def _select_by_id(self, session, conversation_id: str):
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
query = (
return (
session.query(StoredConversationMetadata)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.filter(StoredConversationMetadataSaas.user_id == UUID(self.user_id))
.filter(StoredConversationMetadata.user_id == self.user_id)
.filter(StoredConversationMetadata.conversation_id == conversation_id)
.filter(StoredConversationMetadata.conversation_version == 'V0')
)
if self.org_id is not None:
query = query.filter(StoredConversationMetadataSaas.org_id == self.org_id)
return query
def _to_external_model(self, conversation_metadata: StoredConversationMetadata):
kwargs = {
c.name: getattr(conversation_metadata, c.name)
for c in StoredConversationMetadata.__table__.columns
if c.name != 'github_user_id' # Skip github_user_id field
}
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
@@ -73,8 +53,6 @@ class SaasConversationStore(ConversationStore):
# Convert string to ProviderType enum
kwargs['git_provider'] = ProviderType(kwargs['git_provider'])
kwargs['user_id'] = self.user_id
# Remove V1 attributes
kwargs.pop('max_budget_per_task', None)
kwargs.pop('cache_read_tokens', None)
@@ -89,10 +67,7 @@ class SaasConversationStore(ConversationStore):
async def save_metadata(self, metadata: ConversationMetadata):
kwargs = dataclasses.asdict(metadata)
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
kwargs.pop('user_id', None)
kwargs.pop('org_id', None)
kwargs['user_id'] = self.user_id
# Convert ProviderType enum to string for storage
if kwargs.get('git_provider') is not None:
@@ -106,41 +81,7 @@ class SaasConversationStore(ConversationStore):
def _save_metadata():
with self.session_maker() as session:
# Save the main conversation metadata
session.merge(stored_metadata)
# Create or update the SaaS metadata record
saas_metadata = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadataSaas.conversation_id
== stored_metadata.conversation_id
)
.first()
)
if not saas_metadata:
saas_metadata = StoredConversationMetadataSaas(
conversation_id=stored_metadata.conversation_id,
user_id=UUID(self.user_id),
org_id=self.org_id,
)
session.add(saas_metadata)
else:
# Validate
expected_user_id = UUID(self.user_id)
expected_org_id = self.org_id
if saas_metadata.user_id != expected_user_id:
raise ValueError(
f'Existing user_id ({saas_metadata.user_id}) does not match expected value ({expected_user_id}).'
)
if expected_org_id and saas_metadata.org_id != expected_org_id:
raise ValueError(
f'Existing org_id ({saas_metadata.org_id}) does not match expected value ({expected_org_id}).'
)
session.commit()
await call_sync_from_async(_save_metadata)
@@ -160,29 +101,8 @@ class SaasConversationStore(ConversationStore):
async def delete_metadata(self, conversation_id: str) -> None:
def _delete_metadata():
with self.session_maker() as session:
saas_record = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadataSaas.conversation_id
== conversation_id,
StoredConversationMetadataSaas.user_id == UUID(self.user_id),
StoredConversationMetadataSaas.org_id == self.org_id,
)
.first()
)
if saas_record:
# Delete both records, but only if the SaaS one exists
session.query(StoredConversationMetadata).filter(
StoredConversationMetadata.conversation_id == conversation_id,
).delete()
session.delete(saas_record)
session.commit()
else:
# No SaaS record found → skip deleting main metadata
session.rollback()
self._select_by_id(session, conversation_id).delete()
session.commit()
await call_sync_from_async(_delete_metadata)
@@ -205,15 +125,7 @@ class SaasConversationStore(ConversationStore):
with self.session_maker() as session:
conversations = (
session.query(StoredConversationMetadata)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.filter(
StoredConversationMetadataSaas.user_id == UUID(self.user_id)
)
.filter(StoredConversationMetadataSaas.org_id == self.org_id)
.filter(StoredConversationMetadata.user_id == self.user_id)
.filter(StoredConversationMetadata.conversation_version == 'V0')
.order_by(StoredConversationMetadata.created_at.desc())
.offset(offset)

View File

@@ -8,7 +8,6 @@ from cryptography.fernet import Fernet
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.stored_custom_secrets import StoredCustomSecrets
from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger
@@ -25,17 +24,14 @@ class SaasSecretsStore(SecretsStore):
async def load(self) -> Secrets | None:
if not self.user_id:
return None
user = await UserStore.get_user_by_id_async(self.user_id)
org_id = user.current_org_id if user else None
with self.session_maker() as session:
# Fetch all secrets for the given user ID
query = session.query(StoredCustomSecrets).filter(
StoredCustomSecrets.keycloak_user_id == self.user_id
settings = (
session.query(StoredCustomSecrets)
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
.all()
)
if org_id is not None:
query = query.filter(StoredCustomSecrets.org_id == org_id)
settings = query.all()
if not settings:
return Secrets()
@@ -52,8 +48,6 @@ class SaasSecretsStore(SecretsStore):
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
async def store(self, item: Secrets):
user = await UserStore.get_user_by_id_async(self.user_id)
org_id = user.current_org_id
with self.session_maker() as session:
# Incoming secrets are always the most updated ones
# Delete all existing records and override with incoming ones
@@ -82,7 +76,6 @@ class SaasSecretsStore(SecretsStore):
for secret_name, secret_value, description in secret_tuples:
new_secret = StoredCustomSecrets(
keycloak_user_id=self.user_id,
org_id=org_id,
secret_name=secret_name,
secret_value=secret_value,
description=description,

View File

@@ -1,28 +1,46 @@
from __future__ import annotations
import asyncio
import binascii
import hashlib
import uuid
import json
import os
from base64 import b64decode, b64encode
from dataclasses import dataclass
import httpx
from cryptography.fernet import Fernet
from integrations import stripe_service
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from server.constants import (
CURRENT_USER_SETTINGS_VERSION,
DEFAULT_INITIAL_BUDGET,
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
REQUIRE_PAYMENT,
USER_SETTINGS_VERSION_TO_MODEL,
get_default_litellm_model,
)
from server.logger import logger
from sqlalchemy.orm import joinedload, sessionmaker
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_store import OrgStore
from storage.user import User
from storage.user_settings import UserSettings
from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.server.settings import Settings
from openhands.storage import get_file_store
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.http_session import httpx_verify_option
# The max possible time to wait for another process to finish creating a user before retrying
_REDIS_CREATE_TIMEOUT_SECONDS = 30
# The delay to wait for another process to finish creating a user before trying to load again
_RETRY_LOAD_DELAY_SECONDS = 2
# Redis key prefix for user creation locks
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
@dataclass
@@ -30,9 +48,8 @@ class SaasSettingsStore(SettingsStore):
user_id: str
session_maker: sessionmaker
config: OpenHandsConfig
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
def _get_user_settings_by_keycloak_id(
def get_user_settings_by_keycloak_id(
self, keycloak_user_id: str, session=None
) -> UserSettings | None:
"""
@@ -68,105 +85,354 @@ class SaasSettingsStore(SettingsStore):
return _get_settings()
async def load(self) -> Settings | None:
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
if not user:
logger.error(f'User not found for ID {self.user_id}')
if not self.user_id:
return None
with self.session_maker() as session:
settings = self.get_user_settings_by_keycloak_id(self.user_id, session)
org_id = user.current_org_id
org_member: OrgMember = None
for om in user.org_members:
if om.org_id == org_id:
org_member = om
break
if not org_member or not org_member.llm_api_key:
return None
org = OrgStore.get_org_by_id(org_id)
if not org:
logger.error(
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
)
return None
kwargs = {
**{
normalized: getattr(org, c.name)
for c in Org.__table__.columns
if (
normalized := c.name.removeprefix('_default_')
.removeprefix('default_')
.lstrip('_')
if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION:
logger.info(
'saas_settings_store:load:triggering_migration',
extra={'user_id': self.user_id},
)
in Settings.model_fields
},
**{
normalized: getattr(user, c.name)
for c in User.__table__.columns
if (normalized := c.name.lstrip('_')) in Settings.model_fields
},
}
kwargs['llm_api_key'] = org_member.llm_api_key
if org_member.max_iterations:
kwargs['max_iterations'] = org_member.max_iterations
if org_member.llm_model:
kwargs['llm_model'] = org_member.llm_model
if org_member.llm_api_key_for_byor:
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
if org_member.llm_base_url:
kwargs['llm_base_url'] = org_member.llm_base_url
if org.v1_enabled is None:
kwargs['v1_enabled'] = True
return await self.create_default_settings(settings)
kwargs = {
c.name: getattr(settings, c.name)
for c in UserSettings.__table__.columns
if c.name in Settings.model_fields
}
self._decrypt_kwargs(kwargs)
settings = Settings(**kwargs)
settings = Settings(**kwargs)
return settings
return settings
async def store(self, item: Settings):
with self.session_maker() as session:
if not item:
return None
kwargs = item.model_dump(context={'expose_secrets': True})
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(self.user_id))
).first()
# Check if provider is OpenHands and generate API key if needed
if item and self._is_openhands_provider(item):
await self._ensure_openhands_api_key(item)
with self.session_maker() as session:
existing = None
kwargs = {}
if item:
kwargs = item.model_dump(context={'expose_secrets': True})
self._encrypt_kwargs(kwargs)
# First check if we have an existing entry in the new table
existing = self.get_user_settings_by_keycloak_id(self.user_id, session)
kwargs = {
key: value
for key, value in kwargs.items()
if key in UserSettings.__table__.columns
}
if existing:
# Update existing entry
for key, value in kwargs.items():
setattr(existing, key, value)
existing.user_version = CURRENT_USER_SETTINGS_VERSION
session.merge(existing)
else:
kwargs['keycloak_user_id'] = self.user_id
kwargs['user_version'] = CURRENT_USER_SETTINGS_VERSION
kwargs.pop('secrets_store', None) # Don't save secrets_store to db
settings = UserSettings(**kwargs)
session.add(settings)
session.commit()
def _get_redis_client(self):
"""Get the Redis client from the Socket.IO manager."""
from openhands.server.shared import sio
return getattr(sio.manager, 'redis', None)
async def _acquire_user_creation_lock(self) -> bool:
"""Attempt to acquire a distributed lock for user creation.
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
Returns False if another process holds the lock.
"""
redis_client = self._get_redis_client()
if redis_client is None:
logger.warning(
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
extra={'user_id': self.user_id},
)
return True # Proceed without locking if Redis is unavailable
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{self.user_id}'
lock_acquired = await redis_client.set(
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
)
return bool(lock_acquired)
async def create_default_settings(self, user_settings: UserSettings | None):
logger.info(
'saas_settings_store:create_default_settings:start',
extra={'user_id': self.user_id},
)
# You must log in before you get default settings
if not self.user_id:
return None
# Prevent duplicate settings creation using distributed lock
if not await self._acquire_user_creation_lock():
# The user is already being created in another thread / process
logger.info(
'saas_settings_store:create_default_settings:waiting_for_lock',
extra={'user_id': self.user_id},
)
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
return await self.load()
# Only users that have specified a payment method get default settings
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
self.user_id
):
logger.info(
'saas_settings_store:create_default_settings:no_payment',
extra={'user_id': self.user_id},
)
return None
settings: Settings | None = None
if user_settings is None:
settings = Settings(
language='en',
enable_proactive_conversation_starters=True,
)
elif isinstance(user_settings, UserSettings):
# Convert UserSettings (SQLAlchemy model) to Settings (Pydantic model)
kwargs = {
c.name: getattr(user_settings, c.name)
for c in UserSettings.__table__.columns
if c.name in Settings.model_fields
}
self._decrypt_kwargs(kwargs)
settings = Settings(**kwargs)
if settings:
settings = await self.update_settings_with_litellm_default(settings)
if settings is None:
logger.info(
'saas_settings_store:create_default_settings:litellm_update_failed',
extra={'user_id': self.user_id},
)
return None
await self.store(settings)
return settings
async def load_legacy_file_store_settings(self, github_user_id: str):
if not github_user_id:
return None
file_store = get_file_store(self.config.file_store, self.config.file_store_path)
path = f'users/github/{github_user_id}/settings.json'
try:
json_str = await call_sync_from_async(file_store.read, path)
logger.info(
'saas_settings_store:load_legacy_file_store_settings:found',
extra={'github_user_id': github_user_id},
)
kwargs = json.loads(json_str)
self._decrypt_kwargs(kwargs)
settings = Settings(**kwargs)
return settings
except FileNotFoundError:
return None
except Exception as e:
logger.error(
'saas_settings_store:load_legacy_file_store_settings:error',
extra={'github_user_id': github_user_id, 'error': str(e)},
)
return None
def _has_custom_settings(
self, settings: Settings, old_user_version: int | None
) -> bool:
"""
Check if user has custom LLM settings that should be preserved.
Returns True if user customized either model or base_url.
Args:
settings: The user's current settings
old_user_version: The user's old settings version, if any
Returns:
True if user has custom settings, False if using old defaults
"""
# Normalize values
user_model = (
settings.llm_model.strip()
if settings.llm_model and settings.llm_model.strip()
else None
)
user_base_url = (
settings.llm_base_url.strip()
if settings.llm_base_url and settings.llm_base_url.strip()
else None
)
# Custom base_url = definitely custom settings (BYOK)
if user_base_url and user_base_url != LITE_LLM_API_URL:
return True
# No model set = using defaults
if not user_model:
return False
# Check if model matches old version's default
if (
old_user_version
and old_user_version < CURRENT_USER_SETTINGS_VERSION
and old_user_version in USER_SETTINGS_VERSION_TO_MODEL
):
old_default_base = USER_SETTINGS_VERSION_TO_MODEL[old_user_version]
user_model_base = user_model.split('/')[-1]
if user_model_base == old_default_base:
return False # Matches old default
return True # Custom model
async def update_settings_with_litellm_default(
self, settings: Settings
) -> Settings | None:
logger.info(
'saas_settings_store:update_settings_with_litellm_default:start',
extra={'user_id': self.user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
key = LITE_LLM_API_KEY
# Check if user has custom settings
has_custom = self._has_custom_settings(settings, settings.user_version)
# Determine model to use (needed before LiteLLM user creation)
llm_model_to_use = (
settings.llm_model
if has_custom and settings.llm_model
else get_default_litellm_model()
)
if not local_deploy:
# Get user info to add to litellm
token_manager = TokenManager()
keycloak_user_info = (
await token_manager.get_user_info_from_user_id(self.user_id) or {}
)
async with httpx.AsyncClient(
verify=httpx_verify_option(),
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
) as client:
# Get the previous max budget to prevent accidental loss.
#
# LiteLLM v1.80+ returns 404 for non-existent users (previously returned empty user_info)
response = await client.get(
f'{LITE_LLM_API_URL}/user/info?user_id={self.user_id}'
)
user_info: dict
if response.status_code == 404:
# New user - doesn't exist in LiteLLM yet (v1.80+ behavior)
user_info = {}
else:
# For any other status, use standard error handling
response.raise_for_status()
response_json = response.json()
user_info = response_json.get('user_info') or {}
logger.info(
f'creating_litellm_user: {self.user_id}; prev_max_budget: {user_info.get("max_budget")}; prev_metadata: {user_info.get("metadata")}'
)
max_budget = user_info.get('max_budget') or DEFAULT_INITIAL_BUDGET
spend = user_info.get('spend') or 0
if not user:
# Check if we need to migrate from user_settings
user_settings = None
with session_maker() as session:
user_settings = self._get_user_settings_by_keycloak_id(
user_settings = self.get_user_settings_by_keycloak_id(
self.user_id, session
)
if user_settings:
user = await UserStore.migrate_user(self.user_id, user_settings)
else:
logger.error(f'User not found for ID {self.user_id}')
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
if (
user_settings
and user_settings.user_version < 4
and user_settings.billing_margin
and user_settings.billing_margin != 1.0
):
billing_margin = user_settings.billing_margin
logger.info(
'user_settings_v4_budget_upgrade',
extra={
'max_budget': max_budget,
'billing_margin': billing_margin,
'spend': spend,
},
)
max_budget *= billing_margin
spend *= billing_margin
user_settings.billing_margin = 1.0
session.commit()
email = keycloak_user_info.get('email')
# We explicitly delete here to guard against odd inherited settings on upgrade.
# We don't care if this fails with a 404
await client.post(
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [self.user_id]}
)
# Create the new litellm user
response = await self._create_user_in_lite_llm(
client, email, max_budget, spend, llm_model_to_use
)
if not response.is_success:
logger.warning(
'duplicate_user_email',
extra={'user_id': self.user_id, 'email': email},
)
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
response = await self._create_user_in_lite_llm(
client, None, max_budget, spend, llm_model_to_use
)
# User failed to create in litellm - this is an unforseen error state...
if not response.is_success:
logger.error(
'error_creating_litellm_user',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [self.user_id],
'email': email,
'max_budget': max_budget,
'spend': spend,
},
)
return None
org_id = user.current_org_id
# Check if provider is OpenHands and generate API key if needed
if self._is_openhands_provider(item):
await self._ensure_openhands_api_key(item, str(org_id))
org_member = None
for om in user.org_members:
if om.org_id == org_id:
org_member = om
break
if not org_member or not org_member.llm_api_key:
return None
org = session.query(Org).filter(Org.id == org_id).first()
if not org:
logger.error(
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
response_json = response.json()
key = response_json['key']
logger.info(
'saas_settings_store:update_settings_with_litellm_default:user_created',
extra={'user_id': self.user_id},
)
return None
for model in (user, org, org_member):
for key, value in kwargs.items():
if hasattr(model, key):
setattr(model, key, value)
if has_custom:
settings.llm_model = settings.llm_model or get_default_litellm_model()
settings.llm_base_url = settings.llm_base_url or LITE_LLM_API_URL
settings.llm_api_key = settings.llm_api_key or SecretStr(key)
else:
settings.llm_model = get_default_litellm_model()
settings.llm_base_url = LITE_LLM_API_URL
settings.llm_api_key = SecretStr(key)
session.commit()
settings.agent = 'CodeActAgent'
return settings
@classmethod
async def get_instance(
@@ -177,9 +443,6 @@ class SaasSettingsStore(SettingsStore):
logger.debug(f'saas_settings_store.get_instance::{user_id}')
return SaasSettingsStore(user_id, session_maker, config)
def _should_encrypt(self, key):
return key in self.ENCRYPT_VALUES
def _decrypt_kwargs(self, kwargs: dict):
fernet = self._fernet()
for key, value in kwargs.items():
@@ -223,24 +486,21 @@ class SaasSettingsStore(SettingsStore):
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
return Fernet(fernet_key)
def _should_encrypt(self, key: str) -> bool:
return key in ('llm_api_key', 'llm_api_key_for_byor', 'search_api_key')
def _is_openhands_provider(self, item: Settings) -> bool:
"""Check if the settings use the OpenHands provider."""
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
async def _ensure_openhands_api_key(self, item: Settings) -> None:
"""Generate and set the OpenHands API key for the given settings.
First checks if an existing key with the OpenHands alias exists,
and reuses it if found. Otherwise, generates a new key.
"""
# Generate new key if none exists
generated_key = await LiteLlmManager.generate_key(
self.user_id,
org_id,
None,
{'type': 'openhands'},
)
generated_key = await self._generate_openhands_key()
if generated_key:
item.llm_api_key = SecretStr(generated_key)
logger.info(
@@ -252,3 +512,83 @@ class SaasSettingsStore(SettingsStore):
'saas_settings_store:store:failed_to_generate_openhands_key',
extra={'user_id': self.user_id},
)
async def _create_user_in_lite_llm(
self,
client: httpx.AsyncClient,
email: str | None,
max_budget: int,
spend: int,
llm_model: str,
):
response = await client.post(
f'{LITE_LLM_API_URL}/user/new',
json={
'user_email': email,
'models': [],
'max_budget': max_budget,
'spend': spend,
'user_id': str(self.user_id),
'teams': [LITE_LLM_TEAM_ID],
'auto_create_key': True,
'send_invite_email': False,
'metadata': {
'version': CURRENT_USER_SETTINGS_VERSION,
'model': llm_model,
},
'key_alias': f'OpenHands Cloud - user {self.user_id}',
},
)
return response
async def _generate_openhands_key(self) -> str | None:
"""Generate a new OpenHands provider key for a user."""
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
logger.warning(
'saas_settings_store:_generate_openhands_key:litellm_config_not_found',
extra={'user_id': self.user_id},
)
return None
try:
async with httpx.AsyncClient(
verify=httpx_verify_option(),
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
) as client:
response = await client.post(
f'{LITE_LLM_API_URL}/key/generate',
json={
'user_id': self.user_id,
'metadata': {'type': 'openhands'},
},
)
response.raise_for_status()
response_json = response.json()
key = response_json.get('key')
if key:
logger.info(
'saas_settings_store:_generate_openhands_key:success',
extra={
'user_id': self.user_id,
'key_length': len(key) if key else 0,
'key_prefix': (
key[:10] + '...' if key and len(key) > 10 else key
),
},
)
return key
else:
logger.error(
'saas_settings_store:_generate_openhands_key:no_key_in_response',
extra={'user_id': self.user_id, 'response_json': response_json},
)
return None
except Exception as e:
logger.exception(
'saas_settings_store:_generate_openhands_key:error',
extra={'user_id': self.user_id, 'error': str(e)},
)
return None

View File

@@ -1,6 +1,4 @@
from sqlalchemy import Boolean, Column, ForeignKey, Identity, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy import Boolean, Column, Identity, Integer, String
from storage.base import Base
@@ -10,9 +8,5 @@ class SlackConversation(Base): # type: ignore
conversation_id = Column(String, nullable=False, index=True)
channel_id = Column(String, nullable=False)
keycloak_user_id = Column(String, nullable=False)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
parent_id = Column(String, nullable=True, index=True)
v1_enabled = Column(Boolean, nullable=True)
# Relationships
org = relationship('Org', back_populates='slack_conversations')

View File

@@ -1,6 +1,4 @@
from sqlalchemy import Column, DateTime, ForeignKey, Identity, Integer, String, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy import Column, DateTime, Identity, Integer, String, text
from storage.base import Base
@@ -8,7 +6,6 @@ class SlackUser(Base): # type: ignore
__tablename__ = 'slack_users'
id = Column(Integer, Identity(), primary_key=True)
keycloak_user_id = Column(String, nullable=False, index=True)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
slack_user_id = Column(String, nullable=False, index=True)
slack_display_name = Column(String, nullable=False)
created_at = Column(
@@ -16,6 +13,3 @@ class SlackUser(Base): # type: ignore
server_default=text('CURRENT_TIMESTAMP'),
nullable=False,
)
# Relationships
org = relationship('Org', back_populates='slack_users')

View File

@@ -4,4 +4,5 @@ from openhands.app_server.app_conversation.sql_app_conversation_info_service imp
StoredConversationMetadata = _StoredConversationMetadata
__all__ = ['StoredConversationMetadata']

View File

@@ -1,28 +0,0 @@
"""
SQLAlchemy model for ConversationMetadataSaas.
This model stores the SaaS-specific metadata for conversations,
containing only the conversation_id, user_id, and org_id.
"""
from sqlalchemy import UUID as SQL_UUID
from sqlalchemy import Column, ForeignKey, String
from sqlalchemy.orm import relationship
from storage.base import Base
class StoredConversationMetadataSaas(Base): # type: ignore
"""SaaS conversation metadata model containing user and org associations."""
__tablename__ = 'conversation_metadata_saas'
conversation_id = Column(String, primary_key=True)
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
# Relationships
user = relationship('User', back_populates='stored_conversation_metadata_saas')
org = relationship('Org', back_populates='stored_conversation_metadata_saas')
__all__ = ['StoredConversationMetadataSaas']

View File

@@ -1,6 +1,4 @@
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy import Column, Identity, Integer, String
from storage.base import Base
@@ -8,10 +6,6 @@ class StoredCustomSecrets(Base): # type: ignore
__tablename__ = 'custom_secrets'
id = Column(Integer, Identity(), primary_key=True)
keycloak_user_id = Column(String, nullable=True, index=True)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
secret_name = Column(String, nullable=False)
secret_value = Column(String, nullable=False)
description = Column(String, nullable=True)
# Relationships
org = relationship('Org', back_populates='user_secrets')

View File

@@ -1,6 +1,4 @@
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy import Column, DateTime, Integer, String, text
from storage.base import Base
@@ -15,7 +13,6 @@ class StripeCustomer(Base): # type: ignore
__tablename__ = 'stripe_customers'
id = Column(Integer, primary_key=True, autoincrement=True)
keycloak_user_id = Column(String, nullable=False)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
stripe_customer_id = Column(String, nullable=False)
created_at = Column(
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
@@ -26,6 +23,3 @@ class StripeCustomer(Base): # type: ignore
onupdate=text('CURRENT_TIMESTAMP'),
nullable=False,
)
# Relationships
org = relationship('Org', back_populates='stripe_customers')

View File

@@ -1,43 +0,0 @@
"""
SQLAlchemy model for User.
"""
from uuid import uuid4
from sqlalchemy import (
UUID,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
)
from sqlalchemy.orm import relationship
from storage.base import Base
class User(Base): # type: ignore
"""User model with organizational relationships."""
__tablename__ = 'user'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
current_org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
role_id = Column(Integer, ForeignKey('role.id'), nullable=True)
accepted_tos = Column(DateTime, nullable=True)
enable_sound_notifications = Column(Boolean, nullable=True)
language = Column(String, nullable=True)
user_consents_to_analytics = Column(Boolean, nullable=True)
email = Column(String, nullable=True)
email_verified = Column(Boolean, nullable=True)
git_user_name = Column(String, nullable=True)
git_user_email = Column(String, nullable=True)
# Relationships
role = relationship('Role', back_populates='users')
org_members = relationship('OrgMember', back_populates='user')
current_org = relationship('Org', back_populates='current_users')
stored_conversation_metadata_saas = relationship(
'StoredConversationMetadataSaas', back_populates='user'
)

View File

@@ -39,6 +39,3 @@ class UserSettings(Base): # type: ignore
git_user_name = Column(String, nullable=True)
git_user_email = Column(String, nullable=True)
v1_enabled = Column(Boolean, nullable=True)
already_migrated = Column(
Boolean, nullable=True, default=False
) # False = not migrated, True = migrated

View File

@@ -1,567 +0,0 @@
"""
Store class for managing users.
"""
import asyncio
import uuid
from typing import Optional
from server.auth.token_manager import TokenManager
from server.constants import (
LITE_LLM_API_URL,
ORG_SETTINGS_VERSION,
PERSONAL_WORKSPACE_VERSION_TO_MODEL,
get_default_litellm_model,
)
from server.logger import logger
from sqlalchemy import select, text
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker, session_maker
from storage.encrypt_utils import decrypt_legacy_model
from storage.org import Org
from storage.org_member import OrgMember
from storage.role_store import RoleStore
from storage.user import User
from storage.user_settings import UserSettings
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
# The max possible time to wait for another process to finish creating a user before retrying
_REDIS_CREATE_TIMEOUT_SECONDS = 30
# The delay to wait for another process to finish creating a user before trying to load again
_RETRY_LOAD_DELAY_SECONDS = 2
# Redis key prefix for user creation locks
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
class UserStore:
"""Store for managing users."""
@staticmethod
async def create_user(
user_id: str,
user_info: dict,
role_id: Optional[int] = None,
) -> User | None:
"""Create a new user."""
with session_maker() as session:
# create personal org
org = Org(
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_name=user_info['preferred_username'],
contact_email=user_info['email'],
v1_enabled=True,
)
session.add(org)
settings = await UserStore.create_default_settings(
org_id=str(org.id), user_id=user_id
)
if not settings:
return None
from storage.org_store import OrgStore
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
for key, value in org_kwargs.items():
if hasattr(org, key):
setattr(org, key, value)
user_kwargs = UserStore.get_kwargs_from_settings(settings)
user = User(
id=uuid.UUID(user_id),
current_org_id=org.id,
role_id=role_id,
**user_kwargs,
)
session.add(user)
role = RoleStore.get_role_by_name('owner')
from storage.org_member_store import OrgMemberStore
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
# avoid setting org member llm fields to use org defaults on user creation
del org_member_kwargs['llm_model']
del org_member_kwargs['llm_base_url']
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id, # owner of your own org.
status='active',
**org_member_kwargs,
)
session.add(org_member)
session.commit()
session.refresh(user)
user.org_members # load org_members
return user
@staticmethod
def _get_redis_client():
"""Get the Redis client from the Socket.IO manager."""
from openhands.server.shared import sio
return getattr(sio.manager, 'redis', None)
@staticmethod
async def _acquire_user_creation_lock(user_id: str) -> bool:
"""Attempt to acquire a distributed lock for user creation.
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
Returns False if another process holds the lock.
"""
redis_client = UserStore._get_redis_client()
if redis_client is None:
logger.warning(
'user_store:_acquire_user_creation_lock:no_redis_client',
extra={'user_id': user_id},
)
return True # Proceed without locking if Redis is unavailable
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
lock_acquired = await redis_client.set(
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
)
return bool(lock_acquired)
@staticmethod
async def migrate_user(
user_id: str,
user_settings: UserSettings,
user_info: dict,
) -> User:
if not user_id or not user_settings:
return None
kwargs = decrypt_legacy_model(
[
'llm_api_key',
'llm_api_key_for_byor',
'search_api_key',
'sandbox_api_key',
],
user_settings,
)
decrypted_user_settings = UserSettings(**kwargs)
with session_maker() as session:
# create personal org
org = Org(
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
org_version=user_settings.user_version,
contact_name=user_info['username'],
contact_email=user_info['email'],
)
session.add(org)
from storage.lite_llm_manager import LiteLlmManager
logger.info(
'user_store:migrate_user:calling_litellm_migrate_entries',
extra={'user_id': user_id},
)
await LiteLlmManager.migrate_entries(
str(org.id),
user_id,
decrypted_user_settings,
)
logger.info(
'user_store:migrate_user:done_litellm_migrate_entries',
extra={'user_id': user_id},
)
custom_settings = UserStore._has_custom_settings(
decrypted_user_settings, user_settings.user_version
)
# avoids circular reference. This migrate method is temprorary until all users are migrated.
from integrations.stripe_service import migrate_customer
logger.info(
'user_store:migrate_user:calling_stripe_migrate_customer',
extra={'user_id': user_id},
)
await migrate_customer(session, user_id, org)
logger.info(
'user_store:migrate_user:done_stripe_migrate_customer',
extra={'user_id': user_id},
)
from storage.org_store import OrgStore
org_kwargs = OrgStore.get_kwargs_from_user_settings(decrypted_user_settings)
org_kwargs.pop('id', None)
# if user has custom settings, set org defaults to current version
if custom_settings:
org_kwargs['default_llm_model'] = get_default_litellm_model()
org_kwargs['llm_base_url'] = LITE_LLM_API_URL
org_kwargs['org_version'] = ORG_SETTINGS_VERSION
for key, value in org_kwargs.items():
if hasattr(org, key):
setattr(org, key, value)
user_kwargs = UserStore.get_kwargs_from_user_settings(
decrypted_user_settings
)
user_kwargs.pop('id', None)
user = User(
id=uuid.UUID(user_id),
current_org_id=org.id,
role_id=None,
**user_kwargs,
)
session.add(user)
logger.info(
'user_store:migrate_user:calling_get_role_by_name',
extra={'user_id': user_id},
)
role = await RoleStore.get_role_by_name_async('owner')
logger.info(
'user_store:migrate_user:done_get_role_by_name',
extra={'user_id': user_id},
)
from storage.org_member_store import OrgMemberStore
org_member_kwargs = OrgMemberStore.get_kwargs_from_user_settings(
decrypted_user_settings
)
# if the user did not have custom settings in the old model,
# then use the org defaults by not setting org_member fields
if not custom_settings:
del org_member_kwargs['llm_model']
del org_member_kwargs['llm_base_url']
del org_member_kwargs['llm_api_key_for_byor']
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id, # owner of your own org.
status='active',
**org_member_kwargs,
)
session.add(org_member)
# Mark the old user_settings as migrated instead of deleting
user_settings.already_migrated = True
session.merge(user_settings)
session.flush()
logger.info(
'user_store:migrate_user:session_flush_complete',
extra={'user_id': user_id},
)
# need to migrate conversation metadata
session.execute(
text("""
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
SELECT
conversation_id,
:user_id,
:user_id
FROM conversation_metadata
WHERE user_id = :user_id
"""),
{'user_id': user_id},
)
# Update org_id for tables that had org_id added
user_uuid = uuid.UUID(user_id)
# Update stripe_customers
session.execute(
text(
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update slack_users
session.execute(
text(
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update slack_conversation
session.execute(
text(
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update api_keys
session.execute(
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update custom_secrets
session.execute(
text(
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update billing_sessions
session.execute(
text(
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
session.commit()
session.refresh(user)
user.org_members # load org_members
logger.info(
'user_store:migrate_user:session_committed',
extra={'user_id': user_id},
)
return user
@staticmethod
def get_user_by_id(user_id: str) -> Optional[User]:
"""Get user by Keycloak user ID (sync version).
Note: This method uses call_async_from_sync internally which creates a new
event loop. If you're already in an async context, use get_user_by_id_async
instead to avoid event loop conflicts.
"""
with session_maker() as session:
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
if user:
return user
# Check if we need to migrate from user_settings
while not call_async_from_sync(
UserStore._acquire_user_creation_lock, GENERAL_TIMEOUT, user_id
):
# The user is already being created in another thread / process
logger.info(
'user_store:create_default_settings:waiting_for_lock',
extra={'user_id': user_id},
)
call_async_from_sync(
asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS
)
# Check for user again as migration could have happened while trying to get the lock.
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
if user:
return user
user_settings = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
)
.first()
)
if user_settings:
token_manager = TokenManager()
user_info = call_async_from_sync(
token_manager.get_user_info_from_user_id,
GENERAL_TIMEOUT,
user_id,
)
user = call_async_from_sync(
UserStore.migrate_user,
GENERAL_TIMEOUT,
user_id,
user_settings,
user_info,
)
return user
else:
return None
@staticmethod
async def get_user_by_id_async(user_id: str) -> Optional[User]:
"""Get user by Keycloak user ID (async version).
This is the preferred method when calling from an async context as it
avoids event loop conflicts that can occur with the sync version.
"""
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
)
user = result.scalars().first()
if user:
return user
# Check if we need to migrate from user_settings
while not await UserStore._acquire_user_creation_lock(user_id):
# The user is already being created in another thread / process
logger.info(
'user_store:get_user_by_id_async:waiting_for_lock',
extra={'user_id': user_id},
)
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
# Check for user again as migration could have happened while trying to get the lock.
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
)
user = result.scalars().first()
if user:
return user
logger.info(
'user_store:get_user_by_id_async:start_migration',
extra={'user_id': user_id},
)
result = await session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
)
)
user_settings = result.scalars().first()
if user_settings:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(user_id)
logger.info(
'user_store:get_user_by_id_async:calling_migrate_user',
extra={'user_id': user_id},
)
user = await UserStore.migrate_user(
user_id,
user_settings,
user_info,
)
return user
else:
return None
@staticmethod
def list_users() -> list[User]:
"""List all users."""
with session_maker() as session:
return session.query(User).all()
# Prevent circular imports
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from openhands.storage.data_models.settings import Settings
@staticmethod
async def create_default_settings(
org_id: str, user_id: str, create_user: bool = True
) -> Optional['Settings']:
logger.info(
'UserStore:create_default_settings:start',
extra={'org_id': org_id, 'user_id': user_id},
)
# You must log in before you get default settings
if not org_id:
return None
from openhands.storage.data_models.settings import Settings
settings = Settings(language='en', enable_proactive_conversation_starters=True)
from storage.lite_llm_manager import LiteLlmManager
settings = await LiteLlmManager.create_entries(
org_id, user_id, settings, create_user
)
if not settings:
logger.info(
'UserStore:create_default_settings:litellm_create_failed',
extra={'org_id': org_id},
)
return None
return settings
@staticmethod
def get_kwargs_from_settings(settings: 'Settings'):
kwargs = {
normalized: getattr(settings, normalized)
for c in User.__table__.columns
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
}
return kwargs
@staticmethod
def get_kwargs_from_user_settings(user_settings: UserSettings):
kwargs = {
normalized: getattr(user_settings, normalized)
for c in User.__table__.columns
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
}
return kwargs
@staticmethod
def _has_custom_settings(
user_settings: UserSettings, old_user_version: int | None
) -> bool:
"""
Check if user has custom LLM settings that should be preserved.
Returns True if user customized either model or base_url.
Args:
settings: The user's current settings
old_user_version: The user's old settings version, if any
Returns:
True if user has custom settings, False if using old defaults
"""
# Normalize values
user_model = (
user_settings.llm_model.strip() or None if user_settings.llm_model else None
)
user_base_url = (
user_settings.llm_base_url.strip() or None
if user_settings.llm_base_url
else None
)
# Custom base_url = definitely custom settings (BYOK)
if user_base_url and user_base_url != LITE_LLM_API_URL:
return True
# No model set = using defaults
if not user_model:
return False
# Check if model matches old version's default
if (
old_user_version
and old_user_version <= ORG_SETTINGS_VERSION
and old_user_version in PERSONAL_WORKSPACE_VERSION_TO_MODEL
):
old_default_base = PERSONAL_WORKSPACE_VERSION_TO_MODEL[old_user_version]
user_model_base = user_model.split('/')[-1]
if user_model_base == old_default_base:
return False # Matches old default
return True # Custom model

View File

@@ -21,7 +21,7 @@ from sqlalchemy import text
# Add the parent directory to the path so we can import from storage
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from server.auth.token_manager import get_keycloak_admin
from storage.database import get_engine
from storage.database import engine
# Configure logging
logging.basicConfig(
@@ -85,7 +85,7 @@ def get_recent_conversations(minutes: int = 60) -> List[Dict[str, Any]]:
created_at DESC
""")
with get_engine().connect() as connection:
with engine.connect() as connection:
result = connection.execute(query, {'minutes': minutes})
conversations = [
{

View File

@@ -13,7 +13,7 @@ from sqlalchemy import text
# Add the parent directory to the path so we can import from storage
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from storage.database import get_engine
from storage.database import engine
def test_conversation_count_query():
@@ -29,8 +29,6 @@ def test_conversation_count_query():
user_id
""")
engine = get_engine()
with engine.connect() as connection:
count_result = connection.execute(count_query)
user_counts = [

View File

@@ -1,9 +1,10 @@
import uuid
from datetime import datetime
from uuid import UUID
import pytest
from server.constants import ORG_SETTINGS_VERSION
from server.constants import CURRENT_USER_SETTINGS_VERSION
from server.maintenance_task_processor.user_version_upgrade_processor import (
UserVersionUpgradeProcessor,
)
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from storage.base import Base
@@ -14,16 +15,11 @@ from storage.conversation_work import ConversationWork
from storage.device_code import DeviceCode # noqa: F401
from storage.feedback import Feedback
from storage.github_app_installation import GithubAppInstallation
from storage.org import Org
from storage.org_member import OrgMember
from storage.role import Role
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import (
StoredConversationMetadataSaas,
)
from storage.stored_offline_token import StoredOfflineToken
from storage.stripe_customer import StripeCustomer
from storage.user import User
from storage.user_settings import UserSettings
@pytest.fixture
@@ -72,6 +68,7 @@ def add_minimal_fixtures(session_maker):
session.add(
StoredConversationMetadata(
conversation_id='mock-conversation-id',
user_id='mock-user-id',
created_at=datetime.fromisoformat('2025-03-07'),
last_updated_at=datetime.fromisoformat('2025-03-08'),
accumulated_cost=5.25,
@@ -80,13 +77,6 @@ def add_minimal_fixtures(session_maker):
total_tokens=750,
)
)
session.add(
StoredConversationMetadataSaas(
conversation_id='mock-conversation-id',
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
)
)
session.add(
StoredOfflineToken(
user_id='mock-user-id',
@@ -95,38 +85,7 @@ def add_minimal_fixtures(session_maker):
updated_at=datetime.fromisoformat('2025-03-08'),
)
)
session.add(
Org(
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
name='mock-org',
org_version=ORG_SETTINGS_VERSION,
enable_default_condenser=True,
enable_proactive_conversation_starters=True,
)
)
session.add(
Role(
id=1,
name='admin',
rank=1,
)
)
session.add(
User(
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
current_org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
user_consents_to_analytics=True,
)
)
session.add(
OrgMember(
org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
user_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
role_id=1,
llm_api_key='mock-api-key',
status='active',
)
)
session.add(
StripeCustomer(
keycloak_user_id='mock-user-id',
@@ -135,6 +94,13 @@ def add_minimal_fixtures(session_maker):
updated_at=datetime.fromisoformat('2025-03-10'),
)
)
session.add(
UserSettings(
keycloak_user_id='mock-user-id',
user_consents_to_analytics=True,
user_version=CURRENT_USER_SETTINGS_VERSION,
)
)
session.add(
ConversationWork(
conversation_id='mock-conversation-id',
@@ -143,6 +109,17 @@ def add_minimal_fixtures(session_maker):
updated_at=datetime.fromisoformat('2025-03-08'),
)
)
maintenance_task = MaintenanceTask(
status=MaintenanceTaskStatus.PENDING,
)
maintenance_task.set_processor(
UserVersionUpgradeProcessor(
user_ids=['mock-user-id'],
created_at=datetime.fromisoformat('2025-03-07'),
updated_at=datetime.fromisoformat('2025-03-08'),
)
)
session.add(maintenance_task)
session.commit()

View File

@@ -6,13 +6,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from integrations.jira.jira_manager import JiraManager
from integrations.jira.jira_payload import (
JiraEventType,
JiraWebhookPayload,
)
from integrations.jira.jira_view import (
JiraExistingConversationView,
JiraNewConversationView,
)
from integrations.models import JobContext
from jinja2 import DictLoader, Environment
from storage.jira_conversation import JiraConversation
from storage.jira_user import JiraUser
@@ -27,7 +25,7 @@ def mock_token_manager():
"""Create a mock TokenManager for testing."""
token_manager = MagicMock()
token_manager.get_user_id_from_user_email = AsyncMock()
token_manager.decrypt_text = MagicMock(return_value='decrypted_key')
token_manager.decrypt_text = MagicMock()
return token_manager
@@ -62,7 +60,6 @@ def sample_jira_workspace():
workspace = MagicMock(spec=JiraWorkspace)
workspace.id = 1
workspace.name = 'test.atlassian.net'
workspace.jira_cloud_id = 'cloud-123'
workspace.admin_user_id = 'admin_id'
workspace.webhook_secret = 'encrypted_secret'
workspace.svc_acc_email = 'service@example.com'
@@ -78,41 +75,22 @@ def sample_user_auth():
user_auth.get_provider_tokens = AsyncMock(return_value={})
user_auth.get_access_token = AsyncMock(return_value='test_token')
user_auth.get_user_id = AsyncMock(return_value='test_user_id')
user_auth.get_secrets = AsyncMock(return_value=None)
return user_auth
@pytest.fixture
def sample_webhook_payload():
"""Create a sample JiraWebhookPayload for testing."""
return JiraWebhookPayload(
event_type=JiraEventType.COMMENT_MENTION,
raw_event='comment_created',
def sample_job_context():
"""Create a sample JobContext for testing."""
return JobContext(
issue_id='12345',
issue_key='TEST-123',
user_msg='Fix this bug @openhands',
user_email='user@test.com',
display_name='Test User',
account_id='user123',
workspace_name='test.atlassian.net',
base_api_url='https://test.atlassian.net',
comment_body='Fix this bug @openhands',
)
@pytest.fixture
def sample_label_webhook_payload():
"""Create a sample labeled ticket JiraWebhookPayload for testing."""
return JiraWebhookPayload(
event_type=JiraEventType.LABELED_TICKET,
raw_event='jira:issue_updated',
issue_id='12345',
issue_key='PROJ-123',
user_email='user@company.com',
display_name='Test User',
account_id='user456',
workspace_name='jira.company.com',
base_api_url='https://jira.company.com',
comment_body='',
issue_title='Test Issue',
issue_description='This is a test issue',
)
@@ -203,17 +181,31 @@ def jira_conversation():
@pytest.fixture
def new_conversation_view(
sample_webhook_payload, sample_user_auth, sample_jira_user, sample_jira_workspace
sample_job_context, sample_user_auth, sample_jira_user, sample_jira_workspace
):
"""JiraNewConversationView instance for testing"""
return JiraNewConversationView(
payload=sample_webhook_payload,
job_context=sample_job_context,
saas_user_auth=sample_user_auth,
jira_user=sample_jira_user,
jira_workspace=sample_jira_workspace,
selected_repo='test/repo1',
conversation_id='conv-123',
)
@pytest.fixture
def existing_conversation_view(
sample_job_context, sample_user_auth, sample_jira_user, sample_jira_workspace
):
"""JiraExistingConversationView instance for testing"""
return JiraExistingConversationView(
job_context=sample_job_context,
saas_user_auth=sample_user_auth,
jira_user=sample_jira_user,
jira_workspace=sample_jira_workspace,
selected_repo='test/repo1',
conversation_id='conv-123',
_decrypted_api_key='decrypted_key',
)

File diff suppressed because it is too large Load Diff

View File

@@ -5,87 +5,28 @@ Tests for Jira view classes and factory.
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from integrations.jira.jira_payload import (
JiraEventType,
JiraPayloadError,
JiraPayloadParser,
JiraPayloadSkipped,
JiraPayloadSuccess,
)
from integrations.jira.jira_types import RepositoryNotFoundError, StartingConvoException
from integrations.jira.jira_types import StartingConvoException
from integrations.jira.jira_view import (
JiraExistingConversationView,
JiraFactory,
JiraNewConversationView,
)
from openhands.core.schema.agent import AgentState
class TestJiraNewConversationView:
"""Tests for JiraNewConversationView"""
@pytest.mark.asyncio
async def test_get_issue_details_success(
self, new_conversation_view, sample_jira_workspace
):
"""Test successful issue details retrieval."""
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
title, description = await new_conversation_view.get_issue_details()
assert title == 'Test Issue'
assert description == 'Test description'
@pytest.mark.asyncio
async def test_get_issue_details_cached(self, new_conversation_view):
"""Test issue details are cached after first call."""
new_conversation_view._issue_title = 'Cached Title'
new_conversation_view._issue_description = 'Cached Description'
title, description = await new_conversation_view.get_issue_details()
assert title == 'Cached Title'
assert description == 'Cached Description'
@pytest.mark.asyncio
async def test_get_issue_details_no_title(self, new_conversation_view):
"""Test issue details with no title raises error."""
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': '', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(StartingConvoException, match='does not have a title'):
await new_conversation_view.get_issue_details()
@pytest.mark.asyncio
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
"""Test _get_instructions method fetches issue details."""
new_conversation_view._issue_title = 'Test Issue'
new_conversation_view._issue_description = 'This is a test issue'
instructions, user_msg = await new_conversation_view._get_instructions(
mock_jinja_env
)
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
assert instructions == 'Test Jira instructions template'
assert 'TEST-123' in user_msg
assert 'Test Issue' in user_msg
assert 'Fix this bug @openhands' in user_msg
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.create_new_conversation')
@patch('integrations.jira.jira_view.integration_store')
async def test_create_or_update_conversation_success(
@@ -97,8 +38,6 @@ class TestJiraNewConversationView:
mock_agent_loop_info,
):
"""Test successful conversation creation"""
new_conversation_view._issue_title = 'Test Issue'
new_conversation_view._issue_description = 'Test description'
mock_create_conversation.return_value = mock_agent_loop_info
mock_store.create_conversation = AsyncMock()
@@ -110,7 +49,6 @@ class TestJiraNewConversationView:
mock_create_conversation.assert_called_once()
mock_store.create_conversation.assert_called_once()
@pytest.mark.asyncio
async def test_create_or_update_conversation_no_repo(
self, new_conversation_view, mock_jinja_env
):
@@ -120,6 +58,18 @@ class TestJiraNewConversationView:
with pytest.raises(StartingConvoException, match='No repository selected'):
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
@patch('integrations.jira.jira_view.create_new_conversation')
async def test_create_or_update_conversation_failure(
self, mock_create_conversation, new_conversation_view, mock_jinja_env
):
"""Test conversation creation failure"""
mock_create_conversation.side_effect = Exception('Creation failed')
with pytest.raises(
StartingConvoException, match='Failed to create conversation'
):
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
def test_get_response_msg(self, new_conversation_view):
"""Test get_response_msg method"""
response = new_conversation_view.get_response_msg()
@@ -130,336 +80,344 @@ class TestJiraNewConversationView:
assert 'conv-123' in response
class TestJiraExistingConversationView:
"""Tests for JiraExistingConversationView"""
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = existing_conversation_view._get_instructions(
mock_jinja_env
)
assert instructions == ''
assert 'TEST-123' in user_msg
assert 'Test Issue' in user_msg
assert 'Fix this bug @openhands' in user_msg
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
@patch('integrations.jira.jira_view.conversation_manager')
@patch('integrations.jira.jira_view.get_final_agent_observation')
async def test_create_or_update_conversation_success(
self,
mock_get_observation,
mock_conversation_manager,
mock_setup_init,
mock_store_impl,
existing_conversation_view,
mock_jinja_env,
mock_conversation_store,
mock_conversation_init_data,
mock_agent_loop_info,
):
"""Test successful existing conversation update"""
# Setup mocks
mock_store_impl.return_value = mock_conversation_store
mock_setup_init.return_value = mock_conversation_init_data
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
return_value=mock_agent_loop_info
)
mock_conversation_manager.send_event_to_conversation = AsyncMock()
# Mock agent observation with RUNNING state
mock_observation = MagicMock()
mock_observation.agent_state = AgentState.RUNNING
mock_get_observation.return_value = [mock_observation]
result = await existing_conversation_view.create_or_update_conversation(
mock_jinja_env
)
assert result == 'conv-123'
mock_conversation_manager.send_event_to_conversation.assert_called_once()
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
async def test_create_or_update_conversation_no_metadata(
self, mock_store_impl, existing_conversation_view, mock_jinja_env
):
"""Test conversation update with no metadata"""
mock_store = AsyncMock()
mock_store.get_metadata.side_effect = FileNotFoundError(
'No such file or directory'
)
mock_store_impl.return_value = mock_store
with pytest.raises(
StartingConvoException, match='Conversation no longer exists'
):
await existing_conversation_view.create_or_update_conversation(
mock_jinja_env
)
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
@patch('integrations.jira.jira_view.conversation_manager')
@patch('integrations.jira.jira_view.get_final_agent_observation')
async def test_create_or_update_conversation_loading_state(
self,
mock_get_observation,
mock_conversation_manager,
mock_setup_init,
mock_store_impl,
existing_conversation_view,
mock_jinja_env,
mock_conversation_store,
mock_conversation_init_data,
mock_agent_loop_info,
):
"""Test conversation update with loading state"""
mock_store_impl.return_value = mock_conversation_store
mock_setup_init.return_value = mock_conversation_init_data
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
return_value=mock_agent_loop_info
)
# Mock agent observation with LOADING state
mock_observation = MagicMock()
mock_observation.agent_state = AgentState.LOADING
mock_get_observation.return_value = [mock_observation]
with pytest.raises(
StartingConvoException, match='Conversation is still starting'
):
await existing_conversation_view.create_or_update_conversation(
mock_jinja_env
)
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
async def test_create_or_update_conversation_failure(
self, mock_store_impl, existing_conversation_view, mock_jinja_env
):
"""Test conversation update failure"""
mock_store_impl.side_effect = Exception('Store error')
with pytest.raises(
StartingConvoException, match='Failed to create conversation'
):
await existing_conversation_view.create_or_update_conversation(
mock_jinja_env
)
def test_get_response_msg(self, existing_conversation_view):
"""Test get_response_msg method"""
response = existing_conversation_view.get_response_msg()
assert "I'm on it!" in response
assert 'Test User' in response
assert 'continue tracking my progress here' in response
assert 'conv-123' in response
class TestJiraFactory:
"""Tests for JiraFactory"""
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
@patch('integrations.jira.jira_view.infer_repo_from_message')
async def test_create_view_success(
@patch('integrations.jira.jira_view.integration_store')
async def test_create_jira_view_from_payload_existing_conversation(
self,
mock_infer_repos,
mock_create_handler,
sample_webhook_payload,
mock_store,
sample_job_context,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
sample_repositories,
jira_conversation,
):
"""Test factory creating view with repo selection."""
# Setup mock provider handler
mock_handler = MagicMock()
mock_handler.verify_repo_provider = AsyncMock(
return_value=sample_repositories[0]
)
mock_create_handler.return_value = mock_handler
# Mock repo inference to return a repo name
mock_infer_repos.return_value = ['test/repo1']
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
view = await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
assert isinstance(view, JiraNewConversationView)
assert view.selected_repo == 'test/repo1'
mock_handler.verify_repo_provider.assert_called_once_with('test/repo1')
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
@patch('integrations.jira.jira_view.infer_repo_from_message')
async def test_create_view_no_repo_in_text(
self,
mock_infer_repos,
mock_create_handler,
sample_webhook_payload,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
):
"""Test factory raises error when no repo mentioned in text."""
mock_handler = MagicMock()
mock_create_handler.return_value = mock_handler
# No repos found in text
mock_infer_repos.return_value = []
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(
RepositoryNotFoundError, match='Could not determine which repository'
):
await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
@patch('integrations.jira.jira_view.infer_repo_from_message')
async def test_create_view_repo_verification_fails(
self,
mock_infer_repos,
mock_create_handler,
sample_webhook_payload,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
):
"""Test factory raises error when repo verification fails."""
mock_handler = MagicMock()
mock_handler.verify_repo_provider = AsyncMock(
side_effect=Exception('Repository not found')
)
mock_create_handler.return_value = mock_handler
# Repos found in text but verification fails
mock_infer_repos.return_value = ['test/repo1', 'test/repo2']
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(
RepositoryNotFoundError,
match='Could not access any of the mentioned repositories',
):
await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
@patch('integrations.jira.jira_view.infer_repo_from_message')
async def test_create_view_multiple_repos_verified(
self,
mock_infer_repos,
mock_create_handler,
sample_webhook_payload,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
sample_repositories,
):
"""Test factory raises error when multiple repos are verified."""
mock_handler = MagicMock()
# Both repos verify successfully
mock_handler.verify_repo_provider = AsyncMock(
side_effect=[sample_repositories[0], sample_repositories[1]]
)
mock_create_handler.return_value = mock_handler
# Multiple repos found in text
mock_infer_repos.return_value = ['test/repo1', 'test/repo2']
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(
RepositoryNotFoundError, match='Multiple repositories found'
):
await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
async def test_create_view_no_provider(
self,
mock_create_handler,
sample_webhook_payload,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
):
"""Test factory raises error when no provider is connected."""
mock_create_handler.return_value = None
mock_response = MagicMock()
mock_response.json.return_value = {
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
}
mock_response.raise_for_status = MagicMock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(
RepositoryNotFoundError, match='No Git provider connected'
):
await JiraFactory.create_view(
payload=sample_webhook_payload,
workspace=sample_jira_workspace,
user=sample_jira_user,
user_auth=sample_user_auth,
decrypted_api_key='test_api_key',
)
class TestJiraPayloadParser:
"""Tests for JiraPayloadParser"""
@pytest.fixture
def parser(self):
"""Create a parser for testing."""
return JiraPayloadParser(oh_label='openhands', inline_oh_label='@openhands')
def test_parse_label_event_success(
self, parser, sample_issue_update_webhook_payload
):
"""Test parsing label event."""
result = parser.parse(sample_issue_update_webhook_payload)
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.LABELED_TICKET
assert result.payload.issue_key == 'PROJ-123'
def test_parse_comment_event_success(self, parser, sample_comment_webhook_payload):
"""Test parsing comment event."""
result = parser.parse(sample_comment_webhook_payload)
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.COMMENT_MENTION
assert result.payload.issue_key == 'TEST-123'
assert '@openhands' in result.payload.comment_body
def test_parse_unknown_event_skipped(self, parser):
"""Test unknown event is skipped."""
payload = {'webhookEvent': 'unknown_event'}
result = parser.parse(payload)
assert isinstance(result, JiraPayloadSkipped)
assert 'Unhandled webhook event type' in result.skip_reason
def test_parse_label_event_wrong_label_skipped(self, parser):
"""Test label event without OH label is skipped."""
payload = {
'webhookEvent': 'jira:issue_updated',
'changelog': {'items': [{'field': 'labels', 'toString': 'other-label'}]},
}
result = parser.parse(payload)
assert isinstance(result, JiraPayloadSkipped)
assert 'does not contain' in result.skip_reason
def test_parse_comment_event_no_mention_skipped(self, parser):
"""Test comment without mention is skipped."""
payload = {
'webhookEvent': 'comment_created',
'comment': {
'body': 'Regular comment',
'author': {'emailAddress': 'test@test.com'},
},
}
result = parser.parse(payload)
assert isinstance(result, JiraPayloadSkipped)
assert 'does not mention' in result.skip_reason
def test_parse_missing_fields_error(self, parser):
"""Test missing required fields returns error."""
payload = {
'webhookEvent': 'jira:issue_updated',
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands'}]},
'issue': {'id': '123'}, # Missing key
'user': {'emailAddress': 'test@test.com'}, # Missing other fields
}
result = parser.parse(payload)
assert isinstance(result, JiraPayloadError)
assert 'Missing required fields' in result.error
class TestJiraPayloadParserStagingLabels:
"""Tests for JiraPayloadParser with staging labels."""
@pytest.fixture
def staging_parser(self):
"""Create a parser with staging labels."""
return JiraPayloadParser(
oh_label='openhands-exp', inline_oh_label='@openhands-exp'
"""Test factory creating existing conversation view"""
mock_store.get_user_conversations_by_issue_id = AsyncMock(
return_value=jira_conversation
)
def test_parse_staging_label(self, staging_parser):
"""Test parsing with staging label."""
payload = {
'webhookEvent': 'jira:issue_updated',
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands-exp'}]},
'issue': {
'id': '123',
'key': 'TEST-1',
'self': 'https://test.atlassian.net/rest/api/2/issue/123',
},
'user': {
'emailAddress': 'test@test.com',
'displayName': 'Test',
'accountId': 'acc123',
'self': 'https://test.atlassian.net/rest/api/2/user',
},
}
result = staging_parser.parse(payload)
view = await JiraFactory.create_jira_view_from_payload(
sample_job_context,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
)
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.LABELED_TICKET
assert isinstance(view, JiraExistingConversationView)
assert view.conversation_id == 'conv-123'
def test_parse_prod_label_in_staging_skipped(self, staging_parser):
"""Test prod label is skipped in staging environment."""
payload = {
'webhookEvent': 'jira:issue_updated',
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands'}]},
}
result = staging_parser.parse(payload)
@patch('integrations.jira.jira_view.integration_store')
async def test_create_jira_view_from_payload_new_conversation(
self,
mock_store,
sample_job_context,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
):
"""Test factory creating new conversation view"""
mock_store.get_user_conversations_by_issue_id = AsyncMock(return_value=None)
assert isinstance(result, JiraPayloadSkipped)
view = await JiraFactory.create_jira_view_from_payload(
sample_job_context,
sample_user_auth,
sample_jira_user,
sample_jira_workspace,
)
assert isinstance(view, JiraNewConversationView)
assert view.conversation_id == ''
async def test_create_jira_view_from_payload_no_user(
self, sample_job_context, sample_user_auth, sample_jira_workspace
):
"""Test factory with no Jira user"""
with pytest.raises(StartingConvoException, match='User not authenticated'):
await JiraFactory.create_jira_view_from_payload(
sample_job_context,
sample_user_auth,
None,
sample_jira_workspace, # type: ignore
)
async def test_create_jira_view_from_payload_no_auth(
self, sample_job_context, sample_jira_user, sample_jira_workspace
):
"""Test factory with no SaaS auth"""
with pytest.raises(StartingConvoException, match='User not authenticated'):
await JiraFactory.create_jira_view_from_payload(
sample_job_context,
None,
sample_jira_user,
sample_jira_workspace, # type: ignore
)
async def test_create_jira_view_from_payload_no_workspace(
self, sample_job_context, sample_user_auth, sample_jira_user
):
"""Test factory with no workspace"""
with pytest.raises(StartingConvoException, match='User not authenticated'):
await JiraFactory.create_jira_view_from_payload(
sample_job_context,
sample_user_auth,
sample_jira_user,
None, # type: ignore
)
class TestJiraViewEdgeCases:
"""Tests for edge cases and error scenarios"""
@patch('integrations.jira.jira_view.create_new_conversation')
@patch('integrations.jira.jira_view.integration_store')
async def test_conversation_creation_with_no_user_secrets(
self,
mock_store,
mock_create_conversation,
new_conversation_view,
mock_jinja_env,
mock_agent_loop_info,
):
"""Test conversation creation when user has no secrets"""
new_conversation_view.saas_user_auth.get_secrets.return_value = None
mock_create_conversation.return_value = mock_agent_loop_info
mock_store.create_conversation = AsyncMock()
result = await new_conversation_view.create_or_update_conversation(
mock_jinja_env
)
assert result == 'conv-123'
# Verify create_new_conversation was called with custom_secrets=None
call_kwargs = mock_create_conversation.call_args[1]
assert call_kwargs['custom_secrets'] is None
@patch('integrations.jira.jira_view.create_new_conversation')
@patch('integrations.jira.jira_view.integration_store')
async def test_conversation_creation_store_failure(
self,
mock_store,
mock_create_conversation,
new_conversation_view,
mock_jinja_env,
mock_agent_loop_info,
):
"""Test conversation creation when store creation fails"""
mock_create_conversation.return_value = mock_agent_loop_info
mock_store.create_conversation = AsyncMock(side_effect=Exception('Store error'))
with pytest.raises(
StartingConvoException, match='Failed to create conversation'
):
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
@patch('integrations.jira.jira_view.conversation_manager')
@patch('integrations.jira.jira_view.get_final_agent_observation')
async def test_existing_conversation_empty_observations(
self,
mock_get_observation,
mock_conversation_manager,
mock_setup_init,
mock_store_impl,
existing_conversation_view,
mock_jinja_env,
mock_conversation_store,
mock_conversation_init_data,
mock_agent_loop_info,
):
"""Test existing conversation with empty observations"""
mock_store_impl.return_value = mock_conversation_store
mock_setup_init.return_value = mock_conversation_init_data
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
return_value=mock_agent_loop_info
)
mock_get_observation.return_value = [] # Empty observations
with pytest.raises(
StartingConvoException, match='Conversation is still starting'
):
await existing_conversation_view.create_or_update_conversation(
mock_jinja_env
)
def test_new_conversation_view_attributes(self, new_conversation_view):
"""Test new conversation view attribute access"""
assert new_conversation_view.job_context.issue_key == 'TEST-123'
assert new_conversation_view.selected_repo == 'test/repo1'
assert new_conversation_view.conversation_id == 'conv-123'
def test_existing_conversation_view_attributes(self, existing_conversation_view):
"""Test existing conversation view attribute access"""
assert existing_conversation_view.job_context.issue_key == 'TEST-123'
assert existing_conversation_view.selected_repo == 'test/repo1'
assert existing_conversation_view.conversation_id == 'conv-123'
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
@patch('integrations.jira.jira_view.conversation_manager')
@patch('integrations.jira.jira_view.get_final_agent_observation')
async def test_existing_conversation_message_send_failure(
self,
mock_get_observation,
mock_conversation_manager,
mock_setup_init,
mock_store_impl,
existing_conversation_view,
mock_jinja_env,
mock_conversation_store,
mock_conversation_init_data,
mock_agent_loop_info,
):
"""Test existing conversation when message sending fails"""
mock_store_impl.return_value = mock_conversation_store
mock_setup_init.return_value = mock_conversation_init_data
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
return_value=mock_agent_loop_info
)
mock_conversation_manager.send_event_to_conversation = AsyncMock(
side_effect=Exception('Send error')
)
# Mock agent observation with RUNNING state
mock_observation = MagicMock()
mock_observation.agent_state = AgentState.RUNNING
mock_get_observation.return_value = [mock_observation]
with pytest.raises(
StartingConvoException, match='Failed to create conversation'
):
await existing_conversation_view.create_or_update_conversation(
mock_jinja_env
)

View File

@@ -1,4 +1,4 @@
"""Test for ResolverUserContext get_secrets and get_latest_token logic.
"""Test for ResolverUserContext get_secrets conversion logic.
This test focuses on testing the actual ResolverUserContext implementation.
"""
@@ -12,8 +12,7 @@ from pydantic import SecretStr
from enterprise.integrations.resolver_context import ResolverUserContext
# Import the real classes we want to test
from openhands.integrations.provider import CustomSecret, ProviderToken
from openhands.integrations.service_types import ProviderType
from openhands.integrations.provider import CustomSecret
# Import the SDK types we need for testing
from openhands.sdk.secret import SecretSource, StaticSecret
@@ -132,135 +131,3 @@ def test_custom_to_static_conversion():
assert isinstance(static_secret, StaticSecret)
assert isinstance(static_secret, SecretSource)
assert static_secret.value.get_secret_value() == secret_value
# ---------------------------------------------------------------------------
# Tests for get_latest_token - ensuring string values are returned
# ---------------------------------------------------------------------------
def create_provider_tokens(
tokens_dict: dict[ProviderType, str],
) -> dict[ProviderType, ProviderToken]:
"""Helper to create provider tokens dictionary."""
return {
provider_type: ProviderToken(token=SecretStr(token_value))
for provider_type, token_value in tokens_dict.items()
}
@pytest.mark.asyncio
async def test_get_latest_token_returns_string(resolver_context, mock_saas_user_auth):
"""Test that get_latest_token returns a string, not a ProviderToken object."""
# Arrange
token_value = 'ghp_test_github_token_123'
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
# Act
result = await resolver_context.get_latest_token(ProviderType.GITHUB)
# Assert
assert result is not None
assert isinstance(result, str), (
f'Expected str, got {type(result).__name__}. '
'get_latest_token must return a string for StaticSecret compatibility.'
)
assert result == token_value
@pytest.mark.asyncio
async def test_get_latest_token_returns_string_for_multiple_providers(
resolver_context, mock_saas_user_auth
):
"""Test that get_latest_token returns strings for all provider types."""
# Arrange
provider_tokens = create_provider_tokens(
{
ProviderType.GITHUB: 'ghp_github_token',
ProviderType.GITLAB: 'glpat_gitlab_token',
ProviderType.BITBUCKET: 'bitbucket_token',
}
)
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
# Act & Assert - verify each provider returns a string
for provider_type, expected_token in [
(ProviderType.GITHUB, 'ghp_github_token'),
(ProviderType.GITLAB, 'glpat_gitlab_token'),
(ProviderType.BITBUCKET, 'bitbucket_token'),
]:
result = await resolver_context.get_latest_token(provider_type)
assert isinstance(
result, str
), f'Expected str for {provider_type.name}, got {type(result).__name__}'
assert result == expected_token
@pytest.mark.asyncio
async def test_get_latest_token_returns_none_for_missing_provider(
resolver_context, mock_saas_user_auth
):
"""Test that get_latest_token returns None when provider is not in tokens."""
# Arrange - only GitHub token available
provider_tokens = create_provider_tokens({ProviderType.GITHUB: 'ghp_token'})
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
# Act - request GitLab token which doesn't exist
result = await resolver_context.get_latest_token(ProviderType.GITLAB)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_get_latest_token_returns_none_when_no_provider_tokens(
resolver_context, mock_saas_user_auth
):
"""Test that get_latest_token returns None when no provider tokens exist."""
# Arrange
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=None)
# Act
result = await resolver_context.get_latest_token(ProviderType.GITHUB)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_get_latest_token_returns_none_for_empty_token(
resolver_context, mock_saas_user_auth
):
"""Test that get_latest_token returns None when provider token has no value."""
# Arrange - provider exists but token is None
provider_tokens = {ProviderType.GITHUB: ProviderToken(token=None)}
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
# Act
result = await resolver_context.get_latest_token(ProviderType.GITHUB)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_get_latest_token_can_be_used_with_static_secret(
resolver_context, mock_saas_user_auth
):
"""Test that get_latest_token result can be used directly with StaticSecret.
This is a critical integration test to ensure the return value is compatible
with how it's used in _setup_secrets_for_git_providers.
"""
# Arrange
token_value = 'ghp_integration_test_token'
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
# Act
token = await resolver_context.get_latest_token(ProviderType.GITHUB)
# Assert - this should NOT raise a ValidationError
static_secret = StaticSecret(value=token, description='GITHUB authentication token')
assert static_secret.get_value() == token_value

View File

@@ -6,18 +6,17 @@ import httpx
import pytest
from fastapi import HTTPException
from server.routes.api_keys import (
delete_byor_key_from_litellm,
get_llm_api_key_for_byor,
verify_byor_key_in_litellm,
)
from storage.lite_llm_manager import LiteLlmManager
class TestVerifyByorKeyInLitellm:
"""Test the verify_byor_key_in_litellm function."""
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('storage.lite_llm_manager.httpx.AsyncClient')
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_valid_key_returns_true(self, mock_client_class):
"""Test that a valid key (200 response) returns True."""
# Arrange
@@ -33,7 +32,7 @@ class TestVerifyByorKeyInLitellm:
mock_client_class.return_value = mock_client
# Act
result = await LiteLlmManager.verify_key(byor_key, user_id)
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is True
@@ -43,8 +42,8 @@ class TestVerifyByorKeyInLitellm:
)
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('storage.lite_llm_manager.httpx.AsyncClient')
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_invalid_key_401_returns_false(self, mock_client_class):
"""Test that an invalid key (401 response) returns False."""
# Arrange
@@ -59,14 +58,14 @@ class TestVerifyByorKeyInLitellm:
mock_client_class.return_value = mock_client
# Act
result = await LiteLlmManager.verify_key(byor_key, user_id)
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('storage.lite_llm_manager.httpx.AsyncClient')
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_invalid_key_403_returns_false(self, mock_client_class):
"""Test that an invalid key (403 response) returns False."""
# Arrange
@@ -81,14 +80,14 @@ class TestVerifyByorKeyInLitellm:
mock_client_class.return_value = mock_client
# Act
result = await LiteLlmManager.verify_key(byor_key, user_id)
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('storage.lite_llm_manager.httpx.AsyncClient')
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_server_error_returns_false(self, mock_client_class):
"""Test that a server error (500) returns False to ensure key validity."""
# Arrange
@@ -104,14 +103,14 @@ class TestVerifyByorKeyInLitellm:
mock_client_class.return_value = mock_client
# Act
result = await LiteLlmManager.verify_key(byor_key, user_id)
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('storage.lite_llm_manager.httpx.AsyncClient')
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_timeout_returns_false(self, mock_client_class):
"""Test that a timeout returns False to ensure key validity."""
# Arrange
@@ -124,14 +123,14 @@ class TestVerifyByorKeyInLitellm:
mock_client_class.return_value = mock_client
# Act
result = await LiteLlmManager.verify_key(byor_key, user_id)
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('storage.lite_llm_manager.httpx.AsyncClient')
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_network_error_returns_false(self, mock_client_class):
"""Test that a network error returns False to ensure key validity."""
# Arrange
@@ -144,13 +143,13 @@ class TestVerifyByorKeyInLitellm:
mock_client_class.return_value = mock_client
# Act
result = await LiteLlmManager.verify_key(byor_key, user_id)
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', None)
@patch('server.routes.api_keys.LITE_LLM_API_URL', None)
async def test_verify_missing_api_url_returns_false(self):
"""Test that missing LITE_LLM_API_URL returns False."""
# Arrange
@@ -158,13 +157,13 @@ class TestVerifyByorKeyInLitellm:
user_id = 'user-123'
# Act
result = await LiteLlmManager.verify_key(byor_key, user_id)
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
async def test_verify_empty_key_returns_false(self):
"""Test that empty key returns False."""
# Arrange
@@ -172,7 +171,7 @@ class TestVerifyByorKeyInLitellm:
user_id = 'user-123'
# Act
result = await LiteLlmManager.verify_key(byor_key, user_id)
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@@ -206,7 +205,7 @@ class TestGetLlmApiKeyForByor:
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_valid_key_in_database_returns_key(
self, mock_get_key, mock_verify_key
@@ -230,7 +229,7 @@ class TestGetLlmApiKeyForByor:
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_invalid_key_in_database_regenerates(
self,
@@ -266,7 +265,7 @@ class TestGetLlmApiKeyForByor:
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_invalid_key_deletion_failure_still_regenerates(
self,
@@ -329,99 +328,3 @@ class TestGetLlmApiKeyForByor:
assert exc_info.value.status_code == 500
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
class TestDeleteByorKeyFromLitellm:
"""Test the delete_byor_key_from_litellm function with alias cleanup."""
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
@patch('storage.user_store.UserStore.get_user_by_id_async')
async def test_delete_constructs_alias_from_user(
self, mock_get_user, mock_delete_key
):
"""Test that delete_byor_key_from_litellm constructs key alias from user."""
# Arrange
user_id = 'user-123'
org_id = 'org-456'
byor_key = 'sk-byor-key-to-delete'
expected_alias = f'BYOR Key - user {user_id}, org {org_id}'
mock_user = MagicMock()
mock_user.current_org_id = org_id
mock_get_user.return_value = mock_user
mock_delete_key.return_value = None
# Act
result = await delete_byor_key_from_litellm(user_id, byor_key)
# Assert
assert result is True
mock_get_user.assert_called_once_with(user_id)
mock_delete_key.assert_called_once_with(byor_key, key_alias=expected_alias)
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
@patch('storage.user_store.UserStore.get_user_by_id_async')
async def test_delete_without_user_passes_no_alias(
self, mock_get_user, mock_delete_key
):
"""Test that when user is not found, no alias is passed."""
# Arrange
user_id = 'user-123'
byor_key = 'sk-byor-key-to-delete'
mock_get_user.return_value = None
mock_delete_key.return_value = None
# Act
result = await delete_byor_key_from_litellm(user_id, byor_key)
# Assert
assert result is True
mock_delete_key.assert_called_once_with(byor_key, key_alias=None)
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
@patch('storage.user_store.UserStore.get_user_by_id_async')
async def test_delete_without_org_id_passes_no_alias(
self, mock_get_user, mock_delete_key
):
"""Test that when user has no current_org_id, no alias is passed."""
# Arrange
user_id = 'user-123'
byor_key = 'sk-byor-key-to-delete'
mock_user = MagicMock()
mock_user.current_org_id = None
mock_get_user.return_value = mock_user
mock_delete_key.return_value = None
# Act
result = await delete_byor_key_from_litellm(user_id, byor_key)
# Assert
assert result is True
mock_delete_key.assert_called_once_with(byor_key, key_alias=None)
@pytest.mark.asyncio
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
@patch('storage.user_store.UserStore.get_user_by_id_async')
async def test_delete_returns_false_on_exception(
self, mock_get_user, mock_delete_key
):
"""Test that exceptions during deletion return False."""
# Arrange
user_id = 'user-123'
byor_key = 'sk-byor-key-to-delete'
mock_user = MagicMock()
mock_user.current_org_id = 'org-456'
mock_get_user.return_value = mock_user
mock_delete_key.side_effect = Exception('LiteLLM API error')
# Act
result = await delete_byor_key_from_litellm(user_id, byor_key)
# Assert
assert result is False

View File

@@ -1,96 +0,0 @@
from unittest.mock import patch
from urllib.parse import parse_qs, urlparse
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import SecretStr
from server.routes.github_proxy import add_github_proxy_routes
@pytest.fixture
def app_with_github_proxy(monkeypatch):
"""Create a FastAPI app with github proxy routes enabled."""
# Enable the github proxy endpoints
monkeypatch.setenv('GITHUB_PROXY_ENDPOINTS', '1')
# Mock the config to have a jwt_secret
mock_config = type(
'MockConfig', (), {'jwt_secret': SecretStr('test-secret-key-for-testing')}
)()
app = FastAPI()
with patch('server.routes.github_proxy.GITHUB_PROXY_ENDPOINTS', True):
with patch('server.routes.github_proxy.config', mock_config):
add_github_proxy_routes(app)
# Return app and mock_config so we can use the same config in tests
return app, mock_config
def test_state_compress_encrypt_and_decrypt_decompress_roundtrip(
app_with_github_proxy, monkeypatch
):
"""
Verify the code path used by github_proxy_start -> github_proxy_callback:
- compress payload, encrypt, base64-encode (what the start code does)
- base64-decode, decrypt, decompress (what the callback code does)
This test exercises the actual endpoints to verify the roundtrip works correctly.
"""
app, mock_config = app_with_github_proxy
client = TestClient(app)
original_state = 'some-state-value'
original_redirect_uri = 'https://example.com/redirect'
# Call github_proxy_start endpoint - it should redirect to GitHub with encrypted state
with patch('server.routes.github_proxy.config', mock_config):
response = client.get(
'/github-proxy/test-subdomain/login/oauth/authorize',
params={
'state': original_state,
'redirect_uri': original_redirect_uri,
'client_id': 'test-client-id',
},
follow_redirects=False,
)
assert response.status_code == 307
redirect_url = response.headers['location']
# Verify it redirects to GitHub
assert redirect_url.startswith('https://github.com/login/oauth/authorize')
# Parse the redirect URL to get the encrypted state
parsed = urlparse(redirect_url)
query_params = parse_qs(parsed.query)
encrypted_state = query_params['state'][0]
# The redirect_uri should now point to our callback
assert 'github-proxy/callback' in query_params['redirect_uri'][0]
# Now simulate GitHub calling back with this encrypted state
with patch('server.routes.github_proxy.config', mock_config):
callback_response = client.get(
'/github-proxy/callback',
params={
'state': encrypted_state,
'code': 'test-auth-code',
},
follow_redirects=False,
)
assert callback_response.status_code == 307
final_redirect = callback_response.headers['location']
# Verify the callback redirects to the original redirect_uri
assert final_redirect.startswith(original_redirect_uri)
# Parse the final redirect to verify the state was decrypted correctly
final_parsed = urlparse(final_redirect)
final_params = parse_qs(final_parsed.query)
assert final_params['state'][0] == original_state
assert final_params['code'][0] == 'test-auth-code'

View File

@@ -1,5 +1,3 @@
import hashlib
import hmac
import json
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
@@ -21,7 +19,6 @@ from server.routes.integration.jira import (
jira_events,
unlink_workspace,
validate_workspace_integration,
verify_jira_signature,
)
@@ -64,35 +61,25 @@ def mock_user_auth():
@pytest.mark.asyncio
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
@patch('server.routes.integration.jira.redis_client', new_callable=MagicMock)
async def test_jira_events_invalid_signature(mock_redis, mock_verify, mock_request):
async def test_jira_events_invalid_signature(mock_redis, mock_manager, mock_request):
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
mock_request.body = AsyncMock(return_value=b'{}')
mock_request.json = AsyncMock(return_value={})
mock_verify.side_effect = HTTPException(
status_code=403, detail="Request signatures didn't match!"
)
mock_manager.validate_request.return_value = (False, None, None)
with pytest.raises(HTTPException) as exc_info:
await jira_events(
mock_request, MagicMock(), x_hub_signature='sha256=invalid'
)
await jira_events(mock_request, MagicMock())
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "Request signatures didn't match!"
assert exc_info.value.detail == 'Invalid webhook signature!'
@pytest.mark.asyncio
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
@patch('server.routes.integration.jira.redis_client')
async def test_jira_events_duplicate_request(mock_redis, mock_verify, mock_request):
async def test_jira_events_duplicate_request(mock_redis, mock_manager, mock_request):
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
mock_request.body = AsyncMock(return_value=b'{}')
mock_request.json = AsyncMock(return_value={})
mock_verify.return_value = None
mock_manager.validate_request.return_value = (True, 'sig123', 'payload')
mock_redis.exists.return_value = True
response = await jira_events(
mock_request, MagicMock(), x_hub_signature='sha256=sig123'
)
response = await jira_events(mock_request, MagicMock())
assert response.status_code == 200
body = json.loads(response.body)
assert body['success'] is True
@@ -361,21 +348,18 @@ class TestJiraLinkCreateValidation:
# Test jira_events error scenarios
@pytest.mark.asyncio
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
@patch('server.routes.integration.jira.redis_client', new_callable=MagicMock)
async def test_jira_events_processing_success(
mock_redis, mock_verify, mock_manager, mock_request
):
async def test_jira_events_processing_success(mock_redis, mock_manager, mock_request):
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
mock_request.body = AsyncMock(return_value=b'{"test": "payload"}')
mock_request.json = AsyncMock(return_value={'test': 'payload'})
mock_verify.return_value = None
mock_manager.validate_request.return_value = (
True,
'sig123',
{'test': 'payload'},
)
mock_redis.exists.return_value = False
background_tasks = MagicMock()
response = await jira_events(
mock_request, background_tasks, x_hub_signature='sha256=sig123'
)
response = await jira_events(mock_request, background_tasks)
assert response.status_code == 200
body = json.loads(response.body)
@@ -385,241 +369,19 @@ async def test_jira_events_processing_success(
@pytest.mark.asyncio
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
@patch('server.routes.integration.jira.redis_client', new_callable=MagicMock)
async def test_jira_events_general_exception(mock_redis, mock_verify, mock_request):
async def test_jira_events_general_exception(mock_redis, mock_manager, mock_request):
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
mock_request.body = AsyncMock(side_effect=Exception('Unexpected error'))
mock_request.json = AsyncMock(return_value={})
mock_manager.validate_request.side_effect = Exception('Unexpected error')
response = await jira_events(
mock_request, MagicMock(), x_hub_signature='sha256=sig123'
)
response = await jira_events(mock_request, MagicMock())
assert response.status_code == 500
body = json.loads(response.body)
assert 'Internal server error processing webhook' in body['error']
# Test verify_jira_signature
class TestVerifyJiraSignature:
"""Test Jira webhook signature verification."""
@pytest.fixture
def sample_payload(self):
"""Sample webhook payload with comment_created event."""
return {
'webhookEvent': 'comment_created',
'comment': {
'body': 'Test comment @openhands',
'author': {
'emailAddress': 'user@test.com',
'displayName': 'Test User',
'self': 'https://test.atlassian.net/rest/api/2/user?accountId=123',
},
},
'issue': {
'id': '12345',
'key': 'TEST-123',
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
},
}
@pytest.fixture
def mock_workspace(self):
"""Create a mock workspace."""
workspace = MagicMock()
workspace.id = 1
workspace.name = 'test.atlassian.net'
workspace.status = 'active'
workspace.webhook_secret = 'encrypted_secret'
return workspace
@pytest.mark.asyncio
@pytest.mark.parametrize(
'signature,expected_detail',
[
(None, 'x-hub-signature header is missing!'),
('', 'x-hub-signature header is missing!'),
],
ids=['signature_none', 'signature_empty'],
)
async def test_missing_signature(self, signature, expected_detail, sample_payload):
"""Test that missing or empty signature raises HTTPException."""
body = json.dumps(sample_payload).encode()
with pytest.raises(HTTPException) as exc_info:
await verify_jira_signature(body, signature, sample_payload)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == expected_detail
@pytest.mark.asyncio
@pytest.mark.parametrize(
'payload',
[
{'webhookEvent': 'unknown_event'},
{'webhookEvent': 'comment_created', 'comment': {}},
{'webhookEvent': 'comment_created', 'comment': {'author': {}}},
{'webhookEvent': 'jira:issue_updated', 'user': {}},
{},
],
ids=[
'unknown_event',
'missing_author',
'missing_self_url',
'issue_updated_missing_self',
'empty_payload',
],
)
@patch('server.routes.integration.jira.jira_manager')
async def test_workspace_name_not_found(self, mock_manager, payload):
"""Test that missing workspace name in payload raises HTTPException."""
mock_manager.get_workspace_name_from_payload.return_value = None
body = json.dumps(payload).encode()
with pytest.raises(HTTPException) as exc_info:
await verify_jira_signature(body, 'valid_signature', payload)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == 'Workspace name not found in payload'
@pytest.mark.asyncio
@patch('server.routes.integration.jira.jira_manager')
async def test_workspace_not_found_in_database(self, mock_manager, sample_payload):
"""Test that workspace not found in database raises HTTPException."""
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
return_value=None
)
body = json.dumps(sample_payload).encode()
with pytest.raises(HTTPException) as exc_info:
await verify_jira_signature(body, 'valid_signature', sample_payload)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == 'Unidentified workspace'
@pytest.mark.asyncio
@pytest.mark.parametrize(
'workspace_status',
['inactive', 'disabled', 'pending'],
ids=['inactive', 'disabled', 'pending'],
)
@patch('server.routes.integration.jira.jira_manager')
async def test_workspace_not_active(
self, mock_manager, workspace_status, sample_payload, mock_workspace
):
"""Test that inactive workspace raises HTTPException."""
mock_workspace.status = workspace_status
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
return_value=mock_workspace
)
body = json.dumps(sample_payload).encode()
with pytest.raises(HTTPException) as exc_info:
await verify_jira_signature(body, 'valid_signature', sample_payload)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == 'Workspace is inactive'
@pytest.mark.asyncio
@patch('server.routes.integration.jira.token_manager')
@patch('server.routes.integration.jira.jira_manager')
async def test_signature_mismatch(
self, mock_manager, mock_token_mgr, sample_payload, mock_workspace
):
"""Test that signature mismatch raises HTTPException."""
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
return_value=mock_workspace
)
mock_token_mgr.decrypt_text.return_value = 'webhook_secret'
body = json.dumps(sample_payload).encode()
with pytest.raises(HTTPException) as exc_info:
await verify_jira_signature(body, 'invalid_signature', sample_payload)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "Request signatures didn't match!"
@pytest.mark.asyncio
@patch('server.routes.integration.jira.token_manager')
@patch('server.routes.integration.jira.jira_manager')
async def test_valid_signature(
self, mock_manager, mock_token_mgr, sample_payload, mock_workspace
):
"""Test that valid signature passes verification."""
webhook_secret = 'webhook_secret'
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
return_value=mock_workspace
)
mock_token_mgr.decrypt_text.return_value = webhook_secret
body = json.dumps(sample_payload).encode()
valid_signature = hmac.new(
webhook_secret.encode(), body, hashlib.sha256
).hexdigest()
# Should not raise any exception
result = await verify_jira_signature(body, valid_signature, sample_payload)
assert result is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
'event_type,payload_key,author_key',
[
('comment_created', 'comment', 'author'),
('jira:issue_updated', 'user', None),
],
ids=['comment_created', 'issue_updated'],
)
@patch('server.routes.integration.jira.token_manager')
@patch('server.routes.integration.jira.jira_manager')
async def test_valid_signature_different_events(
self,
mock_manager,
mock_token_mgr,
event_type,
payload_key,
author_key,
mock_workspace,
):
"""Test valid signature verification for different webhook events."""
webhook_secret = 'webhook_secret'
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
return_value=mock_workspace
)
mock_token_mgr.decrypt_text.return_value = webhook_secret
if event_type == 'comment_created':
payload = {
'webhookEvent': event_type,
'comment': {
'body': 'Test',
'author': {
'self': 'https://test.atlassian.net/rest/api/2/user?id=1'
},
},
}
else:
payload = {
'webhookEvent': event_type,
'user': {'self': 'https://test.atlassian.net/rest/api/2/user?id=1'},
}
body = json.dumps(payload).encode()
valid_signature = hmac.new(
webhook_secret.encode(), body, hashlib.sha256
).hexdigest()
result = await verify_jira_signature(body, valid_signature, payload)
assert result is None
# Test create_jira_workspace error scenarios
@pytest.mark.asyncio
@patch('server.routes.integration.jira.get_user_auth')

File diff suppressed because it is too large Load Diff

View File

@@ -82,7 +82,7 @@ class TestGetUserId:
session_maker_with_minimal_fixtures,
):
user_id = _get_user_id('mock-conversation-id')
assert user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
assert user_id == 'mock-user-id'
def test_get_user_id_conversation_not_found(self, session_maker):
"""Test getting user ID when conversation doesn't exist."""
@@ -105,12 +105,10 @@ class TestGetSessionApiKey:
return_value=[mock_agent_loop_info]
)
api_key = await _get_session_api_key(
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
)
api_key = await _get_session_api_key('user-123', 'conv-456')
assert api_key == 'test-api-key'
mock_manager.get_agent_loop_info.assert_called_once_with(
'5594c7b6-f959-4b81-92e9-b09c206f5081', filter_to_sids={'conv-456'}
'user-123', filter_to_sids={'conv-456'}
)
@pytest.mark.asyncio
@@ -120,9 +118,7 @@ class TestGetSessionApiKey:
mock_manager.get_agent_loop_info = AsyncMock(return_value=[])
with pytest.raises(IndexError):
await _get_session_api_key(
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
)
await _get_session_api_key('user-123', 'conv-456')
class TestProcessEvent:
@@ -146,15 +142,10 @@ class TestProcessEvent:
mock_event = MagicMock()
mock_event_from_dict.return_value = mock_event
await process_event(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
'conv-456',
'events/event-1.json',
content,
)
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
mock_file_store.write.assert_called_once_with(
'users/5594c7b6-f959-4b81-92e9-b09c206f5081/conversations/conv-456/events/event-1.json',
'users/user-123/conversations/conv-456/events/event-1.json',
json.dumps(content),
)
mock_event_from_dict.assert_called_once_with(content)
@@ -186,19 +177,14 @@ class TestProcessEvent:
)
mock_event_from_dict.return_value = mock_event
await process_event(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
'conv-456',
'events/event-1.json',
content,
)
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
mock_file_store.write.assert_called_once()
mock_event_from_dict.assert_called_once_with(content)
mock_invoke_callbacks.assert_called_once_with('conv-456', mock_event)
mock_update_working_seconds.assert_called_once()
mock_event_store_class.assert_called_once_with(
'conv-456', mock_file_store, '5594c7b6-f959-4b81-92e9-b09c206f5081'
'conv-456', mock_file_store, 'user-123'
)
@pytest.mark.asyncio
@@ -226,12 +212,7 @@ class TestProcessEvent:
mock_event.agent_state = 'running' # Set RUNNING state to skip the update
mock_event_from_dict.return_value = mock_event
await process_event(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
'conv-456',
'events/event-1.json',
content,
)
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
mock_file_store.write.assert_called_once()
mock_event_from_dict.assert_called_once_with(content)

View File

@@ -1,368 +0,0 @@
"""Tests for SaasSQLAppConversationInfoService.
This module tests the SAAS implementation of SQLAppConversationInfoService,
focusing on user isolation, SAAS metadata handling, and multi-tenant functionality.
"""
from datetime import datetime, timezone
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from enterprise.server.utils.saas_app_conversation_info_injector import (
SaasSQLAppConversationInfoService,
)
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.service_types import ProviderType
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
yield db_session
@pytest.fixture
def service(async_session) -> SaasSQLAppConversationInfoService:
"""Create a SQLAppConversationInfoService instance for testing."""
return SaasSQLAppConversationInfoService(
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
)
@pytest.fixture
def service_with_user(async_session) -> SaasSQLAppConversationInfoService:
"""Create a SQLAppConversationInfoService instance with a user_id for testing."""
return SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111'),
)
@pytest.fixture
def sample_conversation_info() -> AppConversationInfo:
"""Create a sample AppConversationInfo for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='a1111111-1111-1111-1111-111111111111',
sandbox_id='sandbox_123',
selected_repository='https://github.com/test/repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Test Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[123, 456],
llm_model='gpt-4',
metrics=None,
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
)
@pytest.fixture
def multiple_conversation_infos() -> list[AppConversationInfo]:
"""Create multiple AppConversationInfo instances for testing."""
base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
return [
AppConversationInfo(
id=uuid4(),
created_by_user_id=None,
sandbox_id=f'sandbox_{i}',
selected_repository=f'https://github.com/test/repo{i}',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title=f'Test Conversation {i}',
trigger=ConversationTrigger.GUI,
pr_number=[i * 100],
llm_model='gpt-4',
metrics=None,
created_at=base_time.replace(hour=12 + i),
updated_at=base_time.replace(hour=12 + i, minute=30),
)
for i in range(1, 6) # Create 5 conversations
]
@pytest.fixture
def mock_db_session():
"""Create a mock database session."""
return AsyncMock()
@pytest.fixture
def user1_context():
"""Create user context for user1."""
return SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111')
@pytest.fixture
def user2_context():
"""Create user context for user2."""
return SpecifyUserContext(user_id='b2222222-2222-2222-2222-222222222222')
@pytest.fixture
def saas_service_user1(mock_db_session, user1_context):
"""Create a SaasSQLAppConversationInfoService instance for user1."""
return SaasSQLAppConversationInfoService(
db_session=mock_db_session, user_context=user1_context
)
@pytest.fixture
def saas_service_user2(mock_db_session, user2_context):
"""Create a SaasSQLAppConversationInfoService instance for user2."""
return SaasSQLAppConversationInfoService(
db_session=mock_db_session, user_context=user2_context
)
class TestSaasSQLAppConversationInfoService:
"""Test suite for SaasSQLAppConversationInfoService."""
def test_service_initialization(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
user1_context: SpecifyUserContext,
):
"""Test that the SAAS service is properly initialized."""
assert saas_service_user1.user_context == user1_context
assert saas_service_user1.db_session is not None
@pytest.mark.asyncio
async def test_user_context_isolation(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
saas_service_user2: SaasSQLAppConversationInfoService,
):
"""Test that different service instances have different user contexts."""
user1_id = await saas_service_user1.user_context.get_user_id()
user2_id = await saas_service_user2.user_context.get_user_id()
assert user1_id == 'a1111111-1111-1111-1111-111111111111'
assert user2_id == 'b2222222-2222-2222-2222-222222222222'
assert user1_id != user2_id
@pytest.mark.asyncio
async def test_secure_select_includes_user_filtering(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
):
"""Test that _secure_select method includes user filtering."""
# This test verifies that the _secure_select method exists and can be called
# The actual SQL generation is tested implicitly through integration
query = await saas_service_user1._secure_select()
assert query is not None
@pytest.mark.asyncio
async def test_to_info_with_user_id_functionality(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
):
"""Test that _to_info_with_user_id properly sets user_id from SAAS metadata."""
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import (
StoredConversationMetadataSaas,
)
# Create mock metadata objects
stored_metadata = MagicMock(spec=StoredConversationMetadata)
stored_metadata.conversation_id = '12345678-1234-5678-1234-567812345678'
stored_metadata.parent_conversation_id = None
stored_metadata.title = 'Test Conversation'
stored_metadata.sandbox_id = 'test-sandbox'
stored_metadata.selected_repository = None
stored_metadata.selected_branch = None
stored_metadata.git_provider = None
stored_metadata.trigger = None
stored_metadata.pr_number = []
stored_metadata.llm_model = None
from datetime import datetime, timezone
stored_metadata.created_at = datetime.now(timezone.utc)
stored_metadata.last_updated_at = datetime.now(timezone.utc)
stored_metadata.accumulated_cost = 0.0
stored_metadata.prompt_tokens = 0
stored_metadata.completion_tokens = 0
stored_metadata.total_tokens = 0
stored_metadata.max_budget_per_task = None
stored_metadata.cache_read_tokens = 0
stored_metadata.cache_write_tokens = 0
stored_metadata.reasoning_tokens = 0
stored_metadata.context_window = 0
stored_metadata.per_turn_token = 0
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
saas_metadata.org_id = UUID('a1111111-1111-1111-1111-111111111111')
# Test the _to_info_with_user_id method
result = saas_service_user1._to_info_with_user_id(
stored_metadata, saas_metadata
)
# Verify that the user_id from SAAS metadata is used
assert result.created_by_user_id == 'a1111111-1111-1111-1111-111111111111'
assert result.title == 'Test Conversation'
assert result.sandbox_id == 'test-sandbox'
@pytest.mark.asyncio
async def test_user_isolation(
self,
async_session: AsyncSession,
multiple_conversation_infos: list[AppConversationInfo],
):
"""Test that user isolation works correctly."""
from unittest.mock import MagicMock
from storage.user import User
# Mock the database session execute method to return mock users
# This mock intercepts User queries and returns a mock user object
# with user_id and org_id the same as the user_id_uuid from the query
original_execute = async_session.execute
async def mock_execute(query):
query_str = str(query)
# Check if this is a User query
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
# Extract the UUID from the query parameters
# The query will have bound parameters, we need to get the UUID value
if hasattr(query, 'compile'):
try:
compiled = query.compile(compile_kwargs={'literal_binds': True})
query_with_params = str(compiled)
# Extract UUID from the query string
import re
# Try both formats: with dashes and without dashes
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
uuid_match = re.search(
uuid_pattern_with_dashes, query_with_params
)
if not uuid_match:
uuid_match = re.search(
uuid_pattern_without_dashes, query_with_params
)
if uuid_match:
user_id_str = uuid_match.group(0)
# If the UUID doesn't have dashes, add them
if len(user_id_str) == 32 and '-' not in user_id_str:
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
user_id_uuid = UUID(user_id_str)
# Create a mock user with user_id and org_id the same as user_id_uuid
mock_user = MagicMock(spec=User)
mock_user.id = user_id_uuid
mock_user.current_org_id = user_id_uuid
# Create a mock result
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_user
return mock_result
except Exception:
# If there's any error in parsing, fall back to original execute
pass
# For all other queries, use the original execute method
return await original_execute(query)
# Apply the mock
async_session.execute = mock_execute
# Create services for different users
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(
user_id='a1111111-1111-1111-1111-111111111111'
),
)
user2_service = SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(
user_id='b2222222-2222-2222-2222-222222222222'
),
)
# Create conversations for different users
user1_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='a1111111-1111-1111-1111-111111111111',
sandbox_id='sandbox_user1',
title='User 1 Conversation',
)
user2_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='b2222222-2222-2222-2222-222222222222',
sandbox_id='sandbox_user2',
title='User 2 Conversation',
)
# Save conversations
await user1_service.save_app_conversation_info(user1_info)
await user2_service.save_app_conversation_info(user2_info)
# User 1 should only see their conversation
user1_page = await user1_service.search_app_conversation_info()
assert len(user1_page.items) == 1
assert (
user1_page.items[0].created_by_user_id
== 'a1111111-1111-1111-1111-111111111111'
)
# User 2 should only see their conversation
user2_page = await user2_service.search_app_conversation_info()
assert len(user2_page.items) == 1
assert (
user2_page.items[0].created_by_user_id
== 'b2222222-2222-2222-2222-222222222222'
)
# User 1 should not be able to get user 2's conversation
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
assert user2_from_user1 is None
# User 2 should not be able to get user 1's conversation
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
assert user1_from_user2 is None

View File

@@ -1,5 +1,5 @@
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from storage.api_key_store import ApiKeyStore
@@ -19,14 +19,6 @@ def mock_session_maker(mock_session):
return session_maker
@pytest.fixture
def mock_user():
"""Mock user with org_id."""
user = MagicMock()
user.current_org_id = 'test-org-123'
return user
@pytest.fixture
def api_key_store(mock_session_maker):
return ApiKeyStore(mock_session_maker)
@@ -41,13 +33,11 @@ def test_generate_api_key(api_key_store):
assert len(key) == len('sk-oh-') + 32
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
def test_create_api_key(api_key_store, mock_session):
"""Test creating an API key."""
# Setup
user_id = 'test-user-123'
name = 'Test Key'
mock_get_user.return_value = mock_user
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
# Execute
@@ -55,15 +45,10 @@ def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
# Verify
assert result == 'test-api-key'
mock_get_user.assert_called_once_with(user_id)
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
api_key_store.generate_api_key.assert_called_once()
# Verify the ApiKey was created with the correct org_id
added_api_key = mock_session.add.call_args[0][0]
assert added_api_key.org_id == mock_user.current_org_id
def test_validate_api_key_valid(api_key_store, mock_session):
"""Test validating a valid API key."""
@@ -219,12 +204,10 @@ def test_delete_api_key_by_id(api_key_store, mock_session):
mock_session.commit.assert_called_once()
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
def test_list_api_keys(api_key_store, mock_session):
"""Test listing API keys for a user."""
# Setup
user_id = 'test-user-123'
mock_get_user.return_value = mock_user
now = datetime.now(UTC)
mock_key1 = MagicMock()
mock_key1.id = 1
@@ -240,17 +223,15 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
mock_key2.last_used_at = None
mock_key2.expires_at = None
# Mock the chained query calls for filtering by user_id and org_id
mock_query = mock_session.query.return_value
mock_filter_user = mock_query.filter.return_value
mock_filter_org = mock_filter_user.filter.return_value
mock_filter_org.all.return_value = [mock_key1, mock_key2]
mock_session.query.return_value.filter.return_value.all.return_value = [
mock_key1,
mock_key2,
]
# Execute
result = api_key_store.list_api_keys(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert len(result) == 2
assert result[0]['id'] == 1
assert result[0]['name'] == 'Key 1'
@@ -263,59 +244,3 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
assert result[1]['created_at'] == now
assert result[1]['last_used_at'] is None
assert result[1]['expires_at'] is None
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user):
"""Test retrieving MCP API key for a user."""
# Setup
user_id = 'test-user-123'
mock_get_user.return_value = mock_user
mock_mcp_key = MagicMock()
mock_mcp_key.name = 'MCP_API_KEY'
mock_mcp_key.key = 'mcp-test-key'
mock_other_key = MagicMock()
mock_other_key.name = 'Other Key'
mock_other_key.key = 'other-test-key'
# Mock the chained query calls for filtering by user_id and org_id
mock_query = mock_session.query.return_value
mock_filter_user = mock_query.filter.return_value
mock_filter_org = mock_filter_user.filter.return_value
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
# Execute
result = api_key_store.retrieve_mcp_api_key(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert result == 'mcp-test-key'
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_retrieve_mcp_api_key_not_found(
mock_get_user, api_key_store, mock_session, mock_user
):
"""Test retrieving MCP API key when none exists."""
# Setup
user_id = 'test-user-123'
mock_get_user.return_value = mock_user
mock_other_key = MagicMock()
mock_other_key.name = 'Other Key'
mock_other_key.key = 'other-test-key'
# Mock the chained query calls for filtering by user_id and org_id
mock_query = mock_session.query.return_value
mock_filter_user = mock_query.filter.return_value
mock_filter_org = mock_filter_user.filter.return_value
mock_filter_org.all.return_value = [mock_other_key]
# Execute
result = api_key_store.retrieve_mcp_api_key(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert result is None

View File

@@ -130,7 +130,6 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -145,15 +144,6 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
)
mock_token_manager.store_idp_tokens = AsyncMock()
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = None
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = False
@@ -175,19 +165,20 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.posthog') as mock_posthog,
):
# Mock user with accepted_tos
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
# Mock the session and query results
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
# Setup UserStore mocks
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
# Mock user settings with accepted_tos
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -237,7 +228,6 @@ async def test_keycloak_callback_email_not_verified(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.email.verify_email', mock_verify_email),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -253,13 +243,6 @@ async def test_keycloak_callback_email_not_verified(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_verifier.is_active.return_value = False
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
@@ -284,7 +267,6 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.email.verify_email', mock_verify_email),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -300,13 +282,6 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_verifier.is_active.return_value = False
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
@@ -334,20 +309,20 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
),
patch('server.routes.auth.KEYCLOAK_REALM_NAME', 'test-realm'),
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.posthog') as mock_posthog,
):
# Mock user with accepted_tos
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
# Setup UserStore mocks
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
# Mock the session and query results
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
# Mock user settings with accepted_tos
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
@@ -560,7 +535,6 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -575,14 +549,6 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
)
mock_token_manager.disable_keycloak_user = AsyncMock()
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = True
# Act
@@ -610,7 +576,6 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -637,15 +602,6 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = False
mock_verifier.is_active.return_value = True
@@ -673,7 +629,6 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -700,15 +655,6 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_domain_blocker.is_active.return_value = False
mock_domain_blocker.is_domain_blocked.return_value = False
mock_verifier.is_active.return_value = True
@@ -734,7 +680,6 @@ async def test_keycloak_callback_missing_email(mock_request):
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -761,16 +706,6 @@ async def test_keycloak_callback_missing_email(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_domain_blocker.is_active.return_value = True
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -790,7 +725,6 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
"""Test keycloak_callback when duplicate email is detected."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.UserStore') as mock_user_store,
):
# Arrange
mock_token_manager.get_keycloak_tokens = AsyncMock(
@@ -807,13 +741,6 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
@@ -834,7 +761,6 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
"""Test keycloak_callback when duplicate is detected but deletion fails."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.UserStore') as mock_user_store,
):
# Arrange
mock_token_manager.get_keycloak_tokens = AsyncMock(
@@ -851,13 +777,6 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
@@ -877,7 +796,6 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
# Arrange
mock_session = MagicMock()
@@ -907,14 +825,6 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -936,7 +846,6 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
# Arrange
mock_session = MagicMock()
@@ -964,14 +873,6 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -997,7 +898,6 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
# Arrange
mock_session = MagicMock()
@@ -1024,14 +924,6 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1151,7 +1043,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -1180,14 +1071,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1229,7 +1112,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -1245,13 +1127,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_domain_blocker.is_domain_blocked.return_value = False
# Patch the module-level recaptcha_service instance
@@ -1297,7 +1172,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -1326,14 +1200,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1384,7 +1250,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -1413,14 +1278,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1468,7 +1325,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -1497,14 +1353,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1551,7 +1399,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -1580,14 +1427,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1631,7 +1470,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -1660,14 +1498,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1697,7 +1527,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -1726,14 +1555,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1769,7 +1590,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.auth.logger') as mock_logger,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
@@ -1798,14 +1618,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -1853,7 +1665,6 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.logger') as mock_logger,
patch('server.routes.email.verify_email', new_callable=AsyncMock),
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -1869,13 +1680,6 @@ class TestKeycloakCallbackRecaptcha:
return_value=False
)
# Setup UserStore mocks
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_domain_blocker.is_domain_blocked.return_value = False
# Patch the module-level recaptcha_service instance

View File

@@ -1,4 +1,3 @@
import uuid
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock, patch
@@ -6,21 +5,22 @@ import pytest
import stripe
from fastapi import HTTPException, Request, status
from httpx import Response
from server.routes import billing
from integrations.stripe_service import has_payment_method
from server.routes.billing import (
CreateBillingSessionResponse,
CreateCheckoutSessionRequest,
GetCreditsResponse,
cancel_callback,
cancel_subscription,
create_checkout_session,
create_customer_setup_session,
create_subscription_checkout_session,
get_credits,
has_payment_method,
success_callback,
)
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from starlette.datastructures import URL
from storage.billing_session_type import BillingSessionType
from storage.stripe_customer import Base as StripeCustomerBase
@@ -78,32 +78,28 @@ def mock_subscription_request():
@pytest.mark.asyncio
async def test_get_credits_lite_llm_error():
with (
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
side_effect=Exception('LiteLLM API Error'),
),
):
with pytest.raises(Exception, match='LiteLLM API Error'):
await get_credits('mock_user')
mock_response = Response(
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get.return_value = mock_response
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
with patch('httpx.AsyncClient', return_value=mock_client):
with pytest.raises(HTTPException) as exc_info:
await get_credits('mock_user')
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert (
exc_info.value.detail
== 'Failed to retrieve credit balance from billing service'
)
@pytest.mark.asyncio
async def test_get_credits_success():
mock_response = Response(
status_code=200,
json={
'user_info': {
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
}
},
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
request=MagicMock(),
)
mock_client = AsyncMock()
@@ -112,23 +108,24 @@ async def test_get_credits_success():
with (
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
patch('httpx.AsyncClient', return_value=mock_client),
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
},
),
):
result = await get_credits('mock_user')
with patch('server.routes.billing.session_maker') as mock_session_maker:
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.first.return_value = MagicMock(
billing_margin=4
)
mock_session_maker.return_value.__enter__.return_value = mock_db_session
assert isinstance(result, GetCreditsResponse)
assert result.credits == Decimal('74.50') # 100.00 - 25.50 = 74.50
result = await get_credits('mock_user')
assert isinstance(result, GetCreditsResponse)
assert result.credits == Decimal(
'74.50'
) # 100.00 - 25.50 = 74.50 (no billing margin applied)
mock_client.__aenter__.return_value.get.assert_called_once_with(
'https://llm-proxy.app.all-hands.dev/user/info?user_id=mock_user',
headers={'x-goog-api-key': None},
)
@pytest.mark.asyncio
@@ -141,9 +138,6 @@ async def test_create_checkout_session_stripe_error(
id='mock-customer', metadata={'user_id': 'mock-user'}
)
mock_customer_create = AsyncMock(return_value=mock_customer)
mock_org = MagicMock()
mock_org.id = uuid.uuid4()
mock_org.contact_email = 'testy@tester.com'
with (
pytest.raises(Exception, match='Stripe API Error'),
patch('stripe.Customer.create_async', mock_customer_create),
@@ -155,15 +149,11 @@ async def test_create_checkout_session_stripe_error(
AsyncMock(side_effect=Exception('Stripe API Error')),
),
patch('integrations.stripe_service.session_maker', session_maker),
patch(
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=mock_org,
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_billing_enabled'),
patch('server.routes.billing.validate_saas_environment'),
):
await create_checkout_session(
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
@@ -184,10 +174,6 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
id='mock-customer', metadata={'user_id': 'mock-user'}
)
mock_customer_create = AsyncMock(return_value=mock_customer)
mock_org = MagicMock()
mock_org_id = uuid.uuid4()
mock_org.id = mock_org_id
mock_org.contact_email = 'testy@tester.com'
with (
patch('stripe.Customer.create_async', mock_customer_create),
patch(
@@ -196,15 +182,11 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
patch('stripe.checkout.Session.create_async', mock_create),
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('integrations.stripe_service.session_maker', session_maker),
patch(
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=mock_org,
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_billing_enabled'),
patch('server.routes.billing.validate_saas_environment'),
):
mock_db_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_db_session
@@ -236,8 +218,8 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
mode='payment',
payment_method_types=['card'],
saved_payment_method_options={'payment_method_save': 'enabled'},
success_url='https://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
success_url='http://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
cancel_url='http://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
)
# Verify database session creation
@@ -271,6 +253,7 @@ async def test_success_callback_stripe_incomplete():
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
@@ -298,51 +281,65 @@ async def test_success_callback_success():
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
mock_lite_llm_response = Response(
status_code=200,
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
request=MagicMock(),
)
mock_lite_llm_update_response = Response(
status_code=200, json={}, request=MagicMock()
)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
},
),
patch(
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
) as mock_update_budget,
patch('httpx.AsyncClient') as mock_client,
):
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_user_settings = MagicMock(billing_margin=None)
mock_db_session.query.return_value.filter.return_value.first.return_value = (
mock_user_settings
)
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete', amount_subtotal=2500, customer='mock_customer_id'
status='complete',
amount_subtotal=2500,
) # $25.00 in cents
mock_client_instance = AsyncMock()
mock_client_instance.__aenter__.return_value.get.return_value = (
mock_lite_llm_response
)
mock_client_instance.__aenter__.return_value.post.return_value = (
mock_lite_llm_update_response
)
mock_client.return_value = mock_client_instance
response = await success_callback('test_session_id', mock_request)
assert response.status_code == 302
assert (
response.headers['location']
== 'https://test.com/settings/billing?checkout=success'
== 'http://test.com/settings/billing?checkout=success'
)
# Verify LiteLLM API calls
mock_update_budget.assert_called_once_with(
'mock_org_id',
125.0, # 100 + (25.00 from Stripe)
mock_client_instance.__aenter__.return_value.get.assert_called_once()
mock_client_instance.__aenter__.return_value.post.assert_called_once_with(
'https://llm-proxy.app.all-hands.dev/user/update',
headers={'x-goog-api-key': None},
json={
'user_id': 'mock_user',
'max_budget': 125,
}, # 100 + (25.00 from Stripe)
)
# Verify database updates
assert mock_billing_session.status == 'completed'
assert mock_billing_session.price == 25.0
mock_db_session.merge.assert_called_once()
mock_db_session.commit.assert_called_once()
@@ -356,28 +353,27 @@ async def test_success_callback_lite_llm_error():
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
side_effect=Exception('LiteLLM API Error'),
),
patch('httpx.AsyncClient') as mock_client,
):
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete', amount_subtotal=2500
status='complete', amount_total=2500
)
mock_client_instance = AsyncMock()
mock_client_instance.__aenter__.return_value.get.side_effect = Exception(
'LiteLLM API Error'
)
mock_client.return_value = mock_client_instance
with pytest.raises(Exception, match='LiteLLM API Error'):
await success_callback('test_session_id', mock_request)
@@ -401,8 +397,7 @@ async def test_cancel_callback_session_not_found():
response = await cancel_callback('test_session_id', mock_request)
assert response.status_code == 302
assert (
response.headers['location']
== 'https://test.com/settings/billing?checkout=cancel'
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
)
# Verify no database updates occurred
@@ -428,8 +423,7 @@ async def test_cancel_callback_success():
assert response.status_code == 302
assert (
response.headers['location']
== 'https://test.com/settings/billing?checkout=cancel'
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
)
# Verify database updates
@@ -441,67 +435,314 @@ async def test_cancel_callback_success():
@pytest.mark.asyncio
async def test_has_payment_method_with_payment_method():
"""Test has_payment_method returns True when user has a payment method."""
mock_has_payment_method = AsyncMock(return_value=True)
with patch(
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
mock_has_payment_method,
with (
patch('integrations.stripe_service.session_maker') as mock_session_maker,
patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[MagicMock()])),
) as mock_list_payment_methods,
):
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.first.return_value = (
MagicMock(stripe_customer_id='cus_test123')
)
result = await has_payment_method('mock_user')
assert result is True
mock_has_payment_method.assert_called_once_with('mock_user')
mock_list_payment_methods.assert_called_once_with('cus_test123')
@pytest.mark.asyncio
async def test_has_payment_method_without_payment_method():
"""Test has_payment_method returns False when user has no payment method."""
mock_has_payment_method = AsyncMock(return_value=False)
with patch(
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
mock_has_payment_method,
with (
patch('integrations.stripe_service.session_maker') as mock_session_maker,
patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[])),
) as mock_list_payment_methods,
):
mock_has_payment_method.return_value = False
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.first.return_value = (
MagicMock(stripe_customer_id='cus_test123')
)
result = await has_payment_method('mock_user')
assert result is False
mock_has_payment_method.assert_called_once_with('mock_user')
mock_list_payment_methods.assert_called_once_with('cus_test123')
@pytest.mark.asyncio
async def test_create_customer_setup_session_success():
"""Test successful creation of customer setup session."""
mock_request = Request(
scope={
'type': 'http',
'path': '/api/billing/create-customer-setup-session',
'server': ('test.com', 80),
'headers': [],
}
)
mock_request._base_url = URL('http://test.com/')
async def test_cancel_subscription_success():
"""Test successful subscription cancellation."""
from datetime import UTC, datetime
mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
mock_session = MagicMock()
mock_session.url = 'https://checkout.stripe.com/test-session'
mock_create = AsyncMock(return_value=mock_session)
from storage.subscription_access import SubscriptionAccess
# Mock active subscription
mock_subscription_access = SubscriptionAccess(
id=1,
status='ACTIVE',
user_id='test_user',
start_at=datetime.now(UTC),
end_at=datetime.now(UTC),
amount_paid=2000,
stripe_invoice_payment_id='pi_test',
stripe_subscription_id='sub_test123',
cancelled_at=None,
)
# Mock Stripe subscription response
mock_stripe_subscription = MagicMock()
mock_stripe_subscription.cancel_at_period_end = True
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch(
'integrations.stripe_service.find_or_create_customer_by_user_id',
AsyncMock(return_value=mock_customer_info),
),
patch('stripe.checkout.Session.create_async', mock_create),
patch('server.routes.billing.validate_billing_enabled'),
'stripe.Subscription.modify_async',
AsyncMock(return_value=mock_stripe_subscription),
) as mock_stripe_modify,
):
result = await create_customer_setup_session(mock_request, 'mock_user')
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
assert isinstance(result, billing.CreateBillingSessionResponse)
# Call the function
result = await cancel_subscription('test_user')
# Verify Stripe API was called
mock_stripe_modify.assert_called_once_with(
'sub_test123', cancel_at_period_end=True
)
# Verify database was updated
assert mock_subscription_access.cancelled_at is not None
mock_session.merge.assert_called_once_with(mock_subscription_access)
mock_session.commit.assert_called_once()
# Verify response
assert result.status_code == 200
@pytest.mark.asyncio
async def test_cancel_subscription_no_active_subscription():
"""Test cancellation when no active subscription exists."""
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
):
# Setup mock session with no subscription found
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
# Call the function and expect HTTPException
with pytest.raises(HTTPException) as exc_info:
await cancel_subscription('test_user')
assert exc_info.value.status_code == 404
assert 'No active subscription found' in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_cancel_subscription_missing_stripe_id():
"""Test cancellation when subscription has no Stripe ID."""
from datetime import UTC, datetime
from storage.subscription_access import SubscriptionAccess
# Mock subscription without Stripe ID
mock_subscription_access = SubscriptionAccess(
id=1,
status='ACTIVE',
user_id='test_user',
start_at=datetime.now(UTC),
end_at=datetime.now(UTC),
amount_paid=2000,
stripe_invoice_payment_id='pi_test',
stripe_subscription_id=None, # Missing Stripe ID
cancelled_at=None,
)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
):
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
# Call the function and expect HTTPException
with pytest.raises(HTTPException) as exc_info:
await cancel_subscription('test_user')
assert exc_info.value.status_code == 400
assert 'missing Stripe subscription ID' in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_cancel_subscription_stripe_error():
"""Test cancellation when Stripe API fails."""
from datetime import UTC, datetime
from storage.subscription_access import SubscriptionAccess
# Mock active subscription
mock_subscription_access = SubscriptionAccess(
id=1,
status='ACTIVE',
user_id='test_user',
start_at=datetime.now(UTC),
end_at=datetime.now(UTC),
amount_paid=2000,
stripe_invoice_payment_id='pi_test',
stripe_subscription_id='sub_test123',
cancelled_at=None,
)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch(
'stripe.Subscription.modify_async',
AsyncMock(side_effect=stripe.StripeError('API Error')),
),
):
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
# Call the function and expect HTTPException
with pytest.raises(HTTPException) as exc_info:
await cancel_subscription('test_user')
assert exc_info.value.status_code == 500
assert 'Failed to cancel subscription' in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_create_subscription_checkout_session_duplicate_prevention(
mock_subscription_request,
):
"""Test that creating a subscription when user already has active subscription raises error."""
from datetime import UTC, datetime
from storage.subscription_access import SubscriptionAccess
# Mock active subscription
mock_subscription_access = SubscriptionAccess(
id=1,
status='ACTIVE',
user_id='test_user',
start_at=datetime.now(UTC),
end_at=datetime.now(UTC),
amount_paid=2000,
stripe_invoice_payment_id='pi_test',
stripe_subscription_id='sub_test123',
cancelled_at=None,
)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('server.routes.billing.validate_saas_environment'),
):
# Setup mock session to return existing active subscription
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
# Call the function and expect HTTPException
with pytest.raises(HTTPException) as exc_info:
await create_subscription_checkout_session(
mock_subscription_request, user_id='test_user'
)
assert exc_info.value.status_code == 400
assert (
'user already has an active subscription'
in str(exc_info.value.detail).lower()
)
@pytest.mark.asyncio
async def test_create_subscription_checkout_session_allows_after_cancellation(
mock_subscription_request,
):
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
mock_session_obj = MagicMock()
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
mock_session_obj.id = 'test_session_id'
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch(
'integrations.stripe_service.find_or_create_customer',
AsyncMock(return_value='cus_test123'),
),
patch(
'stripe.checkout.Session.create_async',
AsyncMock(return_value=mock_session_obj),
),
patch(
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
),
patch('server.routes.billing.validate_saas_environment'),
):
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
# Should succeed
result = await create_subscription_checkout_session(
mock_subscription_request, user_id='test_user'
)
assert isinstance(result, CreateBillingSessionResponse)
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
# Verify Stripe session creation parameters
mock_create.assert_called_once_with(
customer='mock-customer-id',
mode='setup',
payment_method_types=['card'],
success_url='https://test.com/?free_credits=success',
cancel_url='https://test.com/',
@pytest.mark.asyncio
async def test_create_subscription_checkout_session_success_no_existing(
mock_subscription_request,
):
"""Test successful subscription creation when no existing subscription."""
mock_session_obj = MagicMock()
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
mock_session_obj.id = 'test_session_id'
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch(
'integrations.stripe_service.find_or_create_customer',
AsyncMock(return_value='cus_test123'),
),
patch(
'stripe.checkout.Session.create_async',
AsyncMock(return_value=mock_session_obj),
),
patch(
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
),
patch('server.routes.billing.validate_saas_environment'),
):
# Setup mock session to return no existing subscription
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
# Should succeed
result = await create_subscription_checkout_session(
mock_subscription_request, user_id='test_user'
)
assert isinstance(result, CreateBillingSessionResponse)
assert result.redirect_url == 'https://checkout.stripe.com/test-session'

View File

@@ -3,7 +3,6 @@ Tests for ConversationCallbackProcessor and ConversationCallback models.
"""
import json
from uuid import UUID
import pytest
from storage.conversation_callback import (
@@ -12,9 +11,6 @@ from storage.conversation_callback import (
ConversationCallbackProcessor,
)
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import (
StoredConversationMetadataSaas,
)
from openhands.events.observation.agent import AgentStateChangedObservation
@@ -84,22 +80,15 @@ class TestConversationCallback:
"""Create a test conversation metadata record."""
with session_maker() as session:
metadata = StoredConversationMetadata(
conversation_id='test_conversation_123'
)
metadata_saas = StoredConversationMetadataSaas(
conversation_id='test_conversation_123',
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
conversation_id='test_conversation_123', user_id='test_user_456'
)
session.add(metadata)
session.add(metadata_saas)
session.commit()
session.refresh(metadata)
yield metadata
# Cleanup
session.delete(metadata)
session.delete(metadata_saas)
session.commit()
def test_callback_creation(self, conversation_metadata, session_maker):

View File

@@ -1,272 +0,0 @@
"""
Unit tests for email validation dependency (get_admin_user_id).
Tests the FastAPI dependency that validates @openhands.dev email domain.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, Request
from server.email_validation import get_admin_user_id
@pytest.fixture
def mock_request():
"""Create a mock FastAPI request."""
return MagicMock(spec=Request)
@pytest.fixture
def mock_user_auth():
"""Create a mock user auth object."""
mock_auth = AsyncMock()
mock_auth.get_user_email = AsyncMock()
return mock_auth
@pytest.mark.asyncio
async def test_get_openhands_user_id_success(mock_request, mock_user_auth):
"""
GIVEN: Valid user ID and @openhands.dev email
WHEN: get_admin_user_id is called
THEN: User ID is returned successfully
"""
# Arrange
user_id = 'test-user-123'
mock_user_auth.get_user_email.return_value = 'test@openhands.dev'
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
# Act
result = await get_admin_user_id(mock_request, user_id)
# Assert
assert result == user_id
mock_user_auth.get_user_email.assert_called_once()
@pytest.mark.asyncio
async def test_get_openhands_user_id_no_user_id(mock_request):
"""
GIVEN: No user ID provided (None)
WHEN: get_admin_user_id is called
THEN: 401 Unauthorized is raised
"""
# Arrange
user_id = None
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_admin_user_id(mock_request, user_id)
assert exc_info.value.status_code == 401
assert 'not authenticated' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_get_openhands_user_id_no_email(mock_request, mock_user_auth):
"""
GIVEN: User ID provided but email is None
WHEN: get_admin_user_id is called
THEN: 401 Unauthorized is raised
"""
# Arrange
user_id = 'test-user-123'
mock_user_auth.get_user_email.return_value = None
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_admin_user_id(mock_request, user_id)
assert exc_info.value.status_code == 401
assert 'email not available' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_get_openhands_user_id_invalid_domain(mock_request, mock_user_auth):
"""
GIVEN: User ID and email with non-@openhands.dev domain
WHEN: get_admin_user_id is called
THEN: 403 Forbidden is raised
"""
# Arrange
user_id = 'test-user-123'
mock_user_auth.get_user_email.return_value = 'test@external.com'
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_admin_user_id(mock_request, user_id)
assert exc_info.value.status_code == 403
assert 'openhands.dev' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_get_openhands_user_id_empty_string_user_id(mock_request):
"""
GIVEN: Empty string user ID
WHEN: get_admin_user_id is called
THEN: 401 Unauthorized is raised
"""
# Arrange
user_id = ''
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_admin_user_id(mock_request, user_id)
assert exc_info.value.status_code == 401
assert 'not authenticated' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_get_openhands_user_id_case_sensitivity(mock_request, mock_user_auth):
"""
GIVEN: Email with uppercase @OPENHANDS.DEV domain
WHEN: get_admin_user_id is called
THEN: 403 Forbidden is raised (case-sensitive check)
"""
# Arrange
user_id = 'test-user-123'
mock_user_auth.get_user_email.return_value = 'test@OPENHANDS.DEV'
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_admin_user_id(mock_request, user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_get_openhands_user_id_subdomain_not_allowed(
mock_request, mock_user_auth
):
"""
GIVEN: Email with subdomain like @test.openhands.dev
WHEN: get_admin_user_id is called
THEN: 403 Forbidden is raised
"""
# Arrange
user_id = 'test-user-123'
mock_user_auth.get_user_email.return_value = 'test@test.openhands.dev'
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_admin_user_id(mock_request, user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_get_openhands_user_id_similar_domain_not_allowed(
mock_request, mock_user_auth
):
"""
GIVEN: Email with similar but different domain like @openhands.dev.fake.com
WHEN: get_admin_user_id is called
THEN: 403 Forbidden is raised
"""
# Arrange
user_id = 'test-user-123'
mock_user_auth.get_user_email.return_value = 'test@openhands.dev.fake.com'
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_admin_user_id(mock_request, user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_get_openhands_user_id_logs_warning_on_invalid_domain(
mock_request, mock_user_auth
):
"""
GIVEN: User with invalid email domain
WHEN: get_admin_user_id is called
THEN: Warning is logged with user_id and email_domain
"""
# Arrange
user_id = 'test-user-123'
invalid_email = 'test@external.com'
mock_user_auth.get_user_email.return_value = invalid_email
with (
patch('server.email_validation.get_user_auth', return_value=mock_user_auth),
patch('server.email_validation.logger') as mock_logger,
):
# Act & Assert
with pytest.raises(HTTPException):
await get_admin_user_id(mock_request, user_id)
# Verify warning was logged
mock_logger.warning.assert_called_once()
call_args = mock_logger.warning.call_args
assert 'Access denied' in call_args[0][0]
assert call_args[1]['extra']['user_id'] == user_id
assert call_args[1]['extra']['email_domain'] == 'external.com'
@pytest.mark.asyncio
async def test_get_openhands_user_id_with_plus_addressing(mock_request, mock_user_auth):
"""
GIVEN: Email with plus addressing (test+tag@openhands.dev)
WHEN: get_admin_user_id is called
THEN: User ID is returned successfully
"""
# Arrange
user_id = 'test-user-123'
mock_user_auth.get_user_email.return_value = 'test+tag@openhands.dev'
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
# Act
result = await get_admin_user_id(mock_request, user_id)
# Assert
assert result == user_id
@pytest.mark.asyncio
async def test_get_openhands_user_id_with_dots_in_local_part(
mock_request, mock_user_auth
):
"""
GIVEN: Email with dots in local part (first.last@openhands.dev)
WHEN: get_admin_user_id is called
THEN: User ID is returned successfully
"""
# Arrange
user_id = 'test-user-123'
mock_user_auth.get_user_email.return_value = 'first.last@openhands.dev'
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
# Act
result = await get_admin_user_id(mock_request, user_id)
# Assert
assert result == user_id
@pytest.mark.asyncio
async def test_get_openhands_user_id_empty_email(mock_request, mock_user_auth):
"""
GIVEN: Empty string email
WHEN: get_admin_user_id is called
THEN: 401 Unauthorized is raised
"""
# Arrange
user_id = 'test-user-123'
mock_user_auth.get_user_email.return_value = ''
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_admin_user_id(mock_request, user_id)
assert exc_info.value.status_code == 401
assert 'email not available' in exc_info.value.detail.lower()

View File

@@ -1,100 +1,114 @@
"""Unit tests for get_user_v1_enabled_setting and is_v1_enabled_for_github_resolver functions."""
"""Unit tests for get_user_v1_enabled_setting function."""
import os
from unittest.mock import MagicMock, patch
import pytest
from integrations.github.github_view import (
get_user_v1_enabled_setting,
is_v1_enabled_for_github_resolver,
)
from integrations.utils import get_user_v1_enabled_setting
@pytest.fixture
def mock_org():
"""Create a mock org object."""
org = MagicMock()
org.v1_enabled = True # Default to True, can be overridden in tests
return org
def mock_user_settings():
"""Create a mock user settings object."""
settings = MagicMock()
settings.v1_enabled = True # Default to True, can be overridden in tests
return settings
@pytest.fixture
def mock_dependencies(mock_org):
def mock_settings_store():
"""Create a mock settings store."""
store = MagicMock()
return store
@pytest.fixture
def mock_config():
"""Create a mock config object."""
return MagicMock()
@pytest.fixture
def mock_session_maker():
"""Create a mock session maker."""
return MagicMock()
@pytest.fixture
def mock_dependencies(
mock_settings_store, mock_config, mock_session_maker, mock_user_settings
):
"""Fixture that patches all the common dependencies."""
# Patch at the source module since SaasSettingsStore is imported inside the function
with patch(
'storage.saas_settings_store.SaasSettingsStore',
return_value=mock_settings_store,
) as mock_store_class, patch(
'integrations.utils.get_config', return_value=mock_config
) as mock_get_config, patch(
'integrations.utils.session_maker', mock_session_maker
), patch(
'integrations.utils.call_sync_from_async',
return_value=mock_org,
) as mock_call_sync, patch('integrations.utils.OrgStore') as mock_org_store:
return_value=mock_user_settings,
) as mock_call_sync:
yield {
'store_class': mock_store_class,
'get_config': mock_get_config,
'session_maker': mock_session_maker,
'call_sync': mock_call_sync,
'org_store': mock_org_store,
'org': mock_org,
'settings_store': mock_settings_store,
'user_settings': mock_user_settings,
}
class TestIsV1EnabledForGithubResolver:
"""Test cases for is_v1_enabled_for_github_resolver function.
This function returns True only if BOTH the environment variable
ENABLE_V1_GITHUB_RESOLVER is true AND the user's org has v1_enabled=True.
"""
@pytest.mark.asyncio
@pytest.mark.parametrize(
'env_var_enabled,user_setting_enabled,expected_result',
[
(False, True, False), # Env var disabled, user enabled -> False
(True, False, False), # Env var enabled, user disabled -> False
(True, True, True), # Both enabled -> True
(False, False, False), # Both disabled -> False
],
)
async def test_v1_enabled_combinations(
self, mock_dependencies, env_var_enabled, user_setting_enabled, expected_result
):
"""Test all combinations of environment variable and user setting values."""
mock_dependencies['org'].v1_enabled = user_setting_enabled
with patch(
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_enabled
):
result = await is_v1_enabled_for_github_resolver('test_user_id')
assert result is expected_result
@pytest.mark.asyncio
@pytest.mark.parametrize(
'env_var_value,env_var_bool,expected_result',
[
('false', False, False), # Environment variable 'false' -> False
('true', True, True), # Environment variable 'true' -> True
],
)
async def test_environment_variable_integration(
self, mock_dependencies, env_var_value, env_var_bool, expected_result
):
"""Test that the function properly reads the ENABLE_V1_GITHUB_RESOLVER environment variable."""
mock_dependencies['org'].v1_enabled = True
with patch.dict(
os.environ, {'ENABLE_V1_GITHUB_RESOLVER': env_var_value}
), patch('integrations.utils.os.getenv', return_value=env_var_value), patch(
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_bool
):
result = await is_v1_enabled_for_github_resolver('test_user_id')
assert result is expected_result
class TestGetUserV1EnabledSetting:
"""Test cases for get_user_v1_enabled_setting function.
"""Test cases for get_user_v1_enabled_setting function."""
This function only returns the user's org v1_enabled setting.
It does NOT check the ENABLE_V1_GITHUB_RESOLVER environment variable.
"""
@pytest.mark.asyncio
@pytest.mark.parametrize(
'user_setting_enabled,expected_result',
[
(True, True), # User enabled -> True
(False, False), # User disabled -> False
],
)
async def test_v1_enabled_user_setting(
self, mock_dependencies, user_setting_enabled, expected_result
):
"""Test that the function returns the user's v1_enabled setting."""
mock_dependencies['user_settings'].v1_enabled = user_setting_enabled
result = await get_user_v1_enabled_setting('test_user_id')
assert result is expected_result
@pytest.mark.asyncio
async def test_returns_false_when_no_user_id(self):
"""Test that the function returns False when no user_id is provided."""
result = await get_user_v1_enabled_setting(None)
assert result is False
result = await get_user_v1_enabled_setting('')
assert result is False
@pytest.mark.asyncio
async def test_returns_false_when_settings_is_none(self, mock_dependencies):
"""Test that the function returns False when settings is None."""
mock_dependencies['call_sync'].return_value = None
result = await get_user_v1_enabled_setting('test_user_id')
assert result is False
@pytest.mark.asyncio
async def test_returns_false_when_v1_enabled_is_none(self, mock_dependencies):
"""Test that the function returns False when v1_enabled is None."""
mock_dependencies['user_settings'].v1_enabled = None
result = await get_user_v1_enabled_setting('test_user_id')
assert result is False
@pytest.mark.asyncio
async def test_function_calls_correct_methods(self, mock_dependencies):
"""Test that the function calls the correct methods with correct parameters."""
mock_dependencies['org'].v1_enabled = True
mock_dependencies['user_settings'].v1_enabled = True
result = await get_user_v1_enabled_setting('test_user_123')
@@ -102,38 +116,13 @@ class TestGetUserV1EnabledSetting:
assert result is True
# Verify correct methods were called with correct parameters
mock_dependencies['get_config'].assert_called_once()
mock_dependencies['store_class'].assert_called_once_with(
user_id='test_user_123',
session_maker=mock_dependencies['session_maker'],
config=mock_dependencies['get_config'].return_value,
)
mock_dependencies['call_sync'].assert_called_once_with(
mock_dependencies['org_store'].get_current_org_from_keycloak_user_id,
mock_dependencies['settings_store'].get_user_settings_by_keycloak_id,
'test_user_123',
)
@pytest.mark.asyncio
async def test_returns_user_setting_true(self, mock_dependencies):
"""Test that the function returns True when org.v1_enabled is True."""
mock_dependencies['org'].v1_enabled = True
result = await get_user_v1_enabled_setting('test_user_123')
assert result is True
@pytest.mark.asyncio
async def test_returns_user_setting_false(self, mock_dependencies):
"""Test that the function returns False when org.v1_enabled is False."""
mock_dependencies['org'].v1_enabled = False
result = await get_user_v1_enabled_setting('test_user_123')
assert result is False
@pytest.mark.asyncio
async def test_no_org_returns_false(self, mock_dependencies):
"""Test that the function returns False when no org is found."""
# Mock call_sync_from_async to return None (no org found)
mock_dependencies['call_sync'].return_value = None
result = await get_user_v1_enabled_setting('test_user_123')
assert result is False
@pytest.mark.asyncio
async def test_org_v1_enabled_none_returns_false(self, mock_dependencies):
"""Test that the function returns False when org.v1_enabled is None."""
mock_dependencies['org'].v1_enabled = None
result = await get_user_v1_enabled_setting('test_user_123')
assert result is False

File diff suppressed because it is too large Load Diff

View File

@@ -1,151 +0,0 @@
"""
Test that the models are correctly defined.
"""
from uuid import uuid4
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from storage.base import Base
from storage.org import Org
from storage.org_member import OrgMember
from storage.user import User
@pytest.fixture
def engine():
engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
return engine
@pytest.fixture
def session_maker(engine):
return sessionmaker(bind=engine)
def test_user_model(session_maker):
"""Test that the User model works correctly."""
with session_maker() as session:
# Create a test org
org = Org(name='test_org')
session.add(org)
session.flush()
# Create a test user
test_user_id = uuid4()
user = User(id=test_user_id, current_org_id=org.id, language='en')
session.add(user)
session.flush()
# Create org_member relationship
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=1,
llm_api_key='test-api-key',
status='active',
)
session.add(org_member)
session.commit()
# Query the user
queried_user = session.query(User).filter(User.id == test_user_id).first()
assert queried_user is not None
assert queried_user.language == 'en'
# Query the org
queried_org = session.query(Org).filter(Org.id == org.id).first()
assert queried_org is not None
assert queried_org.name == 'test_org'
# Query the org_member relationship
queried_org_member = (
session.query(OrgMember)
.filter(OrgMember.org_id == org.id, OrgMember.user_id == user.id)
.first()
)
assert queried_org_member is not None
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
def test_user_model_git_user_fields(session_maker):
"""Test that git_user_name and git_user_email columns exist and work correctly."""
with session_maker() as session:
# Arrange
org = Org(name='test_org_git')
session.add(org)
session.flush()
test_user_id = uuid4()
# Act
user = User(
id=test_user_id,
current_org_id=org.id,
git_user_name='Test Git Author',
git_user_email='git@example.com',
)
session.add(user)
session.commit()
# Assert
queried_user = session.query(User).filter(User.id == test_user_id).first()
assert queried_user.git_user_name == 'Test Git Author'
assert queried_user.git_user_email == 'git@example.com'
def test_user_model_git_user_fields_nullable(session_maker):
"""Test that git_user_name and git_user_email can be null."""
with session_maker() as session:
# Arrange
org = Org(name='test_org_nullable')
session.add(org)
session.flush()
test_user_id = uuid4()
# Act - create user without git fields
user = User(
id=test_user_id,
current_org_id=org.id,
)
session.add(user)
session.commit()
# Assert
queried_user = session.query(User).filter(User.id == test_user_id).first()
assert queried_user.git_user_name is None
assert queried_user.git_user_email is None
def test_user_model_git_user_fields_in_table_columns():
"""Test that git_user_name and git_user_email are in User table columns."""
# Arrange & Act
column_names = [c.name for c in User.__table__.columns]
# Assert
assert 'git_user_name' in column_names
assert 'git_user_email' in column_names
def test_user_model_git_user_fields_hasattr(session_maker):
"""Test that hasattr returns True for git_user_* fields on User model.
This verifies the fix for SaasSettingsStore.store() which uses hasattr
to determine if a field should be persisted to a model.
"""
with session_maker() as session:
# Arrange
org = Org(name='test_org_hasattr')
session.add(org)
session.flush()
user = User(id=uuid4(), current_org_id=org.id)
session.add(user)
session.flush()
# Assert - hasattr must return True for store() to work
assert hasattr(user, 'git_user_name')
assert hasattr(user, 'git_user_email')

View File

@@ -1,253 +0,0 @@
import uuid
from unittest.mock import patch
# Mock the database module before importing OrgMemberStore
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.role import Role
from storage.user import User
def test_get_org_members(session_maker):
# Test getting org_members by org ID
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user1 = User(id=uuid.uuid4(), current_org_id=org.id)
user2 = User(id=uuid.uuid4(), current_org_id=org.id)
role = Role(name='admin', rank=1)
session.add_all([user1, user2, role])
session.flush()
org_member1 = OrgMember(
org_id=org.id,
user_id=user1.id,
role_id=role.id,
llm_api_key='test-key-1',
status='active',
)
org_member2 = OrgMember(
org_id=org.id,
user_id=user2.id,
role_id=role.id,
llm_api_key='test-key-2',
status='active',
)
session.add_all([org_member1, org_member2])
session.commit()
org_id = org.id
# Test retrieval
with patch('storage.org_member_store.session_maker', session_maker):
org_members = OrgMemberStore.get_org_members(org_id)
assert len(org_members) == 2
api_keys = [om.llm_api_key.get_secret_value() for om in org_members]
assert 'test-key-1' in api_keys
assert 'test-key-2' in api_keys
def test_get_user_orgs(session_maker):
# Test getting org_members by user ID
with session_maker() as session:
# Create test data
org1 = Org(name='test-org-1')
org2 = Org(name='test-org-2')
session.add_all([org1, org2])
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org1.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.flush()
org_member1 = OrgMember(
org_id=org1.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-1',
status='active',
)
org_member2 = OrgMember(
org_id=org2.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-2',
status='active',
)
session.add_all([org_member1, org_member2])
session.commit()
user_id = user.id
# Test retrieval
with patch('storage.org_member_store.session_maker', session_maker):
org_members = OrgMemberStore.get_user_orgs(user_id)
assert len(org_members) == 2
api_keys = [ou.llm_api_key.get_secret_value() for ou in org_members]
assert 'test-key-1' in api_keys
assert 'test-key-2' in api_keys
def test_get_org_member(session_maker):
# Test getting org_member by org and user ID
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.flush()
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key',
status='active',
)
session.add(org_member)
session.commit()
org_id = org.id
user_id = user.id
# Test retrieval
with patch('storage.org_member_store.session_maker', session_maker):
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
assert retrieved_org_member is not None
assert retrieved_org_member.org_id == org_id
assert retrieved_org_member.user_id == user_id
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
def test_add_user_to_org(session_maker):
# Test adding a user to an org
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.commit()
org_id = org.id
user_id = user.id
role_id = role.id
# Test creation
with patch('storage.org_member_store.session_maker', session_maker):
org_member = OrgMemberStore.add_user_to_org(
org_id=org_id,
user_id=user_id,
role_id=role_id,
llm_api_key='new-test-key',
status='active',
)
assert org_member is not None
assert org_member.org_id == org_id
assert org_member.user_id == user_id
assert org_member.role_id == role_id
assert org_member.llm_api_key.get_secret_value() == 'new-test-key'
assert org_member.status == 'active'
def test_update_user_role_in_org(session_maker):
# Test updating user role in org
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
role1 = Role(name='admin', rank=1)
role2 = Role(name='user', rank=2)
session.add_all([user, role1, role2])
session.flush()
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role1.id,
llm_api_key='test-key',
status='active',
)
session.add(org_member)
session.commit()
org_id = org.id
user_id = user.id
role2_id = role2.id
# Test update
with patch('storage.org_member_store.session_maker', session_maker):
updated_org_member = OrgMemberStore.update_user_role_in_org(
org_id=org_id, user_id=user_id, role_id=role2_id, status='inactive'
)
assert updated_org_member is not None
assert updated_org_member.role_id == role2_id
assert updated_org_member.status == 'inactive'
def test_update_user_role_in_org_not_found(session_maker):
# Test updating org_member that doesn't exist
from uuid import uuid4
with patch('storage.org_member_store.session_maker', session_maker):
updated_org_member = OrgMemberStore.update_user_role_in_org(
org_id=uuid4(), user_id=99999, role_id=1
)
assert updated_org_member is None
def test_remove_user_from_org(session_maker):
# Test removing a user from an org
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.flush()
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key',
status='active',
)
session.add(org_member)
session.commit()
org_id = org.id
user_id = user.id
# Test removal
with patch('storage.org_member_store.session_maker', session_maker):
result = OrgMemberStore.remove_user_from_org(org_id, user_id)
assert result is True
# Verify it's removed
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
assert retrieved_org_member is None
def test_remove_user_from_org_not_found(session_maker):
# Test removing user from org that doesn't exist
from uuid import uuid4
with patch('storage.org_member_store.session_maker', session_maker):
result = OrgMemberStore.remove_user_from_org(uuid4(), 99999)
assert result is False

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More