mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
fix-async-
...
pre-org-ch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c6501ff35d |
6
.github/workflows/ghcr-build.yml
vendored
6
.github/workflows/ghcr-build.yml
vendored
@@ -39,7 +39,8 @@ jobs:
|
||||
run: |
|
||||
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
|
||||
json=$(jq -n -c '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" }
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" },
|
||||
{ image: "ubuntu:24.04", tag: "ubuntu" }
|
||||
]')
|
||||
else
|
||||
json=$(jq -n -c '[
|
||||
@@ -148,9 +149,6 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_TAGS }}
|
||||
platforms: ${{ env.DOCKER_PLATFORM }}
|
||||
# Caching directives to boost performance
|
||||
cache-from: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}
|
||||
cache-to: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }},mode=max
|
||||
build-args: ${{ env.DOCKER_BUILD_ARGS }}
|
||||
context: containers/runtime
|
||||
provenance: false
|
||||
|
||||
@@ -161,7 +161,7 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
|
||||
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
|
||||
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.2-nikolaik`
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.1-nikolaik`
|
||||
|
||||
## Develop inside Docker container
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/runtime}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.2-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.1-nikolaik}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -7,8 +7,7 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-31536c8-python}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.1-nikolaik}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -2,7 +2,7 @@ BACKEND_HOST ?= "127.0.0.1"
|
||||
BACKEND_PORT = 3000
|
||||
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
|
||||
FRONTEND_PORT = 3001
|
||||
OPENHANDS_PATH ?= ".."
|
||||
OPENHANDS_PATH ?= "../../OpenHands"
|
||||
OPENHANDS := $(OPENHANDS_PATH)
|
||||
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build
|
||||
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Downgrade script for migrated users.
|
||||
|
||||
This script identifies users who have been migrated (already_migrated=True)
|
||||
and reverts them back to the pre-migration state.
|
||||
|
||||
Usage:
|
||||
# Dry run - just list the users that would be downgraded
|
||||
python downgrade_migrated_users.py --dry-run
|
||||
|
||||
# Downgrade a specific user by their keycloak_user_id
|
||||
python downgrade_migrated_users.py --user-id <user_id>
|
||||
|
||||
# Downgrade all migrated users (with confirmation)
|
||||
python downgrade_migrated_users.py --all
|
||||
|
||||
# Downgrade all migrated users without confirmation (dangerous!)
|
||||
python downgrade_migrated_users.py --all --no-confirm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
# Add the enterprise directory to the path
|
||||
sys.path.insert(0, '/workspace/project/OpenHands/enterprise')
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from storage.database import session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
|
||||
def get_migrated_users() -> list[str]:
|
||||
"""Get list of keycloak_user_ids for users who have been migrated.
|
||||
|
||||
This includes:
|
||||
1. Users with already_migrated=True in user_settings (migrated users)
|
||||
2. Users in the 'user' table who don't have a user_settings entry (new sign-ups)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Get users from user_settings with already_migrated=True
|
||||
migrated_result = session.execute(
|
||||
select(UserSettings.keycloak_user_id).where(
|
||||
UserSettings.already_migrated.is_(True)
|
||||
)
|
||||
)
|
||||
migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]}
|
||||
|
||||
# Get users from the 'user' table (new sign-ups won't have user_settings)
|
||||
# These are users who signed up after the migration was deployed
|
||||
new_signup_result = session.execute(
|
||||
text("""
|
||||
SELECT CAST(u.id AS VARCHAR)
|
||||
FROM "user" u
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM user_settings us
|
||||
WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR)
|
||||
)
|
||||
""")
|
||||
)
|
||||
new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]}
|
||||
|
||||
# Combine both sets
|
||||
all_users = migrated_users | new_signups
|
||||
return list(all_users)
|
||||
|
||||
|
||||
async def downgrade_user(user_id: str) -> bool:
|
||||
"""Downgrade a single user.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak_user_id to downgrade
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await UserStore.downgrade_user(user_id)
|
||||
if result:
|
||||
print(f'✓ Successfully downgraded user: {user_id}')
|
||||
return True
|
||||
else:
|
||||
print(f'✗ Failed to downgrade user: {user_id}')
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f'✗ Error downgrading user {user_id}: {e}')
|
||||
logger.exception(
|
||||
'downgrade_script:error',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Downgrade migrated users back to pre-migration state'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Just list users that would be downgraded, without making changes',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--user-id',
|
||||
type=str,
|
||||
help='Downgrade a specific user by keycloak_user_id',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--all',
|
||||
action='store_true',
|
||||
help='Downgrade all migrated users',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-confirm',
|
||||
action='store_true',
|
||||
help='Skip confirmation prompt (use with caution!)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get list of migrated users
|
||||
migrated_users = get_migrated_users()
|
||||
print(f'\nFound {len(migrated_users)} migrated user(s).')
|
||||
|
||||
if args.dry_run:
|
||||
print('\n--- DRY RUN MODE ---')
|
||||
print('The following users would be downgraded:')
|
||||
for user_id in migrated_users:
|
||||
print(f' - {user_id}')
|
||||
print('\nNo changes were made.')
|
||||
return
|
||||
|
||||
if args.user_id:
|
||||
# Downgrade a specific user
|
||||
if args.user_id not in migrated_users:
|
||||
print(f'\nUser {args.user_id} is not in the migrated users list.')
|
||||
print('Either the user was not migrated, or the user_id is incorrect.')
|
||||
return
|
||||
|
||||
print(f'\nDowngrading user: {args.user_id}')
|
||||
if not args.no_confirm:
|
||||
confirm = input('Are you sure? (yes/no): ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
success = await downgrade_user(args.user_id)
|
||||
if success:
|
||||
print('\nDowngrade completed successfully.')
|
||||
else:
|
||||
print('\nDowngrade failed. Check logs for details.')
|
||||
sys.exit(1)
|
||||
|
||||
elif args.all:
|
||||
# Downgrade all migrated users
|
||||
if not migrated_users:
|
||||
print('\nNo migrated users to downgrade.')
|
||||
return
|
||||
|
||||
print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).')
|
||||
if not args.no_confirm:
|
||||
print('\nThis will:')
|
||||
print(' - Revert LiteLLM team/user budget settings')
|
||||
print(' - Delete organization entries')
|
||||
print(' - Delete user entries in the new schema')
|
||||
print(' - Reset the already_migrated flag')
|
||||
print('\nUsers to downgrade:')
|
||||
for user_id in migrated_users[:10]: # Show first 10
|
||||
print(f' - {user_id}')
|
||||
if len(migrated_users) > 10:
|
||||
print(f' ... and {len(migrated_users) - 10} more')
|
||||
|
||||
confirm = input('\nType "yes" to proceed: ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
print('\nStarting downgrade...\n')
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for user_id in migrated_users:
|
||||
success = await downgrade_user(user_id)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
print('\n--- Summary ---')
|
||||
print(f'Successful: {success_count}')
|
||||
print(f'Failed: {fail_count}')
|
||||
|
||||
if fail_count > 0:
|
||||
sys.exit(1)
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
print('\nPlease specify --dry-run, --user-id, or --all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -26,14 +26,12 @@ from integrations.utils import (
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
MissingSettingsError,
|
||||
@@ -349,7 +347,7 @@ class GithubManager(Manager):
|
||||
|
||||
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except (AuthenticationError, ExpiredError, SessionExpiredError) as e:
|
||||
except SessionExpiredError as e:
|
||||
logger.warning(
|
||||
f'[GitHub] Session expired for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.store_repo_utils import store_repositories_in_db
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
@@ -78,17 +78,19 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
# Check global setting first - if disabled globally, return False
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.enable_proactive_conversation_starters is None:
|
||||
return False
|
||||
|
||||
def _get_setting():
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
return settings.enable_proactive_conversation_starters
|
||||
|
||||
|
||||
# =================================================
|
||||
@@ -149,7 +151,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
@@ -188,7 +189,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
@@ -327,6 +327,7 @@ class GithubIssueComment(GithubIssue):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
@@ -396,6 +397,7 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
pr_title=self.title,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.store_repo_utils import store_repositories_in_db
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
|
||||
@@ -1,37 +1,22 @@
|
||||
"""Jira integration manager.
|
||||
|
||||
This module orchestrates the processing of Jira webhook events:
|
||||
1. Parse webhook payload (via JiraPayloadParser)
|
||||
2. Validate workspace
|
||||
3. Authenticate user
|
||||
4. Create view with repository selection (via JiraFactory)
|
||||
5. Start conversation job
|
||||
|
||||
The manager delegates payload parsing to JiraPayloadParser and view creation
|
||||
to JiraFactory, keeping the orchestration logic clean and traceable.
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from integrations.jira.jira_payload import (
|
||||
JiraPayloadError,
|
||||
JiraPayloadParser,
|
||||
JiraPayloadSkipped,
|
||||
JiraPayloadSuccess,
|
||||
JiraWebhookPayload,
|
||||
from fastapi import Request
|
||||
from integrations.jira.jira_types import JiraViewInterface
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraFactory,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
from integrations.jira.jira_types import (
|
||||
JiraViewInterface,
|
||||
RepositoryNotFoundError,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.jira.jira_view import JiraFactory, JiraNewConversationView
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import Message
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.utils import (
|
||||
HOST,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
get_oh_labels,
|
||||
filter_potential_repos_by_user_msg,
|
||||
get_session_expired_message,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
@@ -43,6 +28,9 @@ from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
MissingSettingsError,
|
||||
@@ -53,211 +41,303 @@ from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
# Get OH labels for this environment
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
class JiraManager(Manager):
|
||||
"""Manager for processing Jira webhook events.
|
||||
|
||||
This class orchestrates the flow from webhook receipt to conversation creation,
|
||||
delegating parsing to JiraPayloadParser and view creation to JiraFactory.
|
||||
"""
|
||||
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = JiraIntegrationStore.get_instance()
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira')
|
||||
)
|
||||
self.payload_parser = JiraPayloadParser(
|
||||
oh_label=OH_LABEL,
|
||||
inline_oh_label=INLINE_OH_LABEL,
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message.
|
||||
|
||||
Flow:
|
||||
1. Parse webhook payload
|
||||
2. Validate workspace exists and is active
|
||||
3. Authenticate user
|
||||
4. Create view (includes fetching issue details and selecting repository)
|
||||
5. Start job
|
||||
|
||||
Each step has clear logging for traceability.
|
||||
"""
|
||||
raw_payload = message.message.get('payload', {})
|
||||
|
||||
# Step 1: Parse webhook payload
|
||||
logger.info(
|
||||
'[Jira] Received webhook',
|
||||
extra={'raw_payload': raw_payload},
|
||||
)
|
||||
|
||||
parse_result = self.payload_parser.parse(raw_payload)
|
||||
|
||||
if isinstance(parse_result, JiraPayloadSkipped):
|
||||
logger.info(
|
||||
'[Jira] Webhook skipped', extra={'reason': parse_result.skip_reason}
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(parse_result, JiraPayloadError):
|
||||
logger.warning(
|
||||
'[Jira] Webhook parse failed', extra={'error': parse_result.error}
|
||||
)
|
||||
return
|
||||
|
||||
payload = parse_result.payload
|
||||
logger.info(
|
||||
'[Jira] Processing webhook',
|
||||
extra={
|
||||
'event_type': payload.event_type.value,
|
||||
'issue_key': payload.issue_key,
|
||||
'user_email': payload.user_email,
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: Validate workspace
|
||||
workspace = await self._get_active_workspace(payload)
|
||||
if not workspace:
|
||||
return
|
||||
|
||||
# Step 3: Authenticate user
|
||||
jira_user, saas_user_auth = await self._authenticate_user(payload, workspace)
|
||||
if not jira_user or not saas_user_auth:
|
||||
return
|
||||
|
||||
# Step 4: Create view (includes issue details fetch and repo selection)
|
||||
decrypted_api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
|
||||
try:
|
||||
view = await JiraFactory.create_view(
|
||||
payload=payload,
|
||||
workspace=workspace,
|
||||
user=jira_user,
|
||||
user_auth=saas_user_auth,
|
||||
decrypted_api_key=decrypted_api_key,
|
||||
)
|
||||
except RepositoryNotFoundError as e:
|
||||
logger.warning(
|
||||
'[Jira] Repository not found',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
await self._send_error_from_payload(payload, workspace, str(e))
|
||||
return
|
||||
except StartingConvoException as e:
|
||||
logger.warning(
|
||||
'[Jira] View creation failed',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
await self._send_error_from_payload(payload, workspace, str(e))
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Unexpected error creating view',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload,
|
||||
workspace,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
)
|
||||
return
|
||||
|
||||
# Step 5: Start job
|
||||
await self.start_job(view)
|
||||
|
||||
async def _get_active_workspace(
|
||||
self, payload: JiraWebhookPayload
|
||||
) -> JiraWorkspace | None:
|
||||
"""Validate and return the workspace for the webhook.
|
||||
|
||||
Returns None if:
|
||||
- Workspace not found
|
||||
- Workspace is inactive
|
||||
- Request is from service account (to prevent recursion)
|
||||
"""
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
payload.workspace_name
|
||||
)
|
||||
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
'[Jira] Workspace not found',
|
||||
extra={'workspace_name': payload.workspace_name},
|
||||
)
|
||||
# Can't send error without workspace credentials
|
||||
return None
|
||||
|
||||
# Prevent recursive triggers from service account
|
||||
if payload.user_email == workspace.svc_acc_email:
|
||||
logger.debug(
|
||||
'[Jira] Ignoring service account trigger',
|
||||
extra={'workspace_name': payload.workspace_name},
|
||||
)
|
||||
return None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(
|
||||
'[Jira] Workspace inactive',
|
||||
extra={'workspace_id': workspace.id, 'status': workspace.status},
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload, workspace, 'Jira integration is not active for your workspace.'
|
||||
)
|
||||
return None
|
||||
|
||||
return workspace
|
||||
|
||||
async def _authenticate_user(
|
||||
self, payload: JiraWebhookPayload, workspace: JiraWorkspace
|
||||
async def authenticate_user(
|
||||
self, jira_user_id: str, workspace_id: int
|
||||
) -> tuple[JiraUser | None, UserAuth | None]:
|
||||
"""Authenticate the Jira user and get OpenHands auth."""
|
||||
"""Authenticate Jira user and get their OpenHands user auth."""
|
||||
|
||||
# Find active Jira user by Keycloak user ID and workspace ID
|
||||
jira_user = await self.integration_store.get_active_user(
|
||||
payload.account_id, workspace.id
|
||||
jira_user_id, workspace_id
|
||||
)
|
||||
|
||||
if not jira_user:
|
||||
logger.warning(
|
||||
'[Jira] User not found or inactive',
|
||||
extra={
|
||||
'account_id': payload.account_id,
|
||||
'user_email': payload.user_email,
|
||||
'workspace_id': workspace.id,
|
||||
},
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload,
|
||||
workspace,
|
||||
f'User {payload.user_email} is not authenticated or active in the Jira integration.',
|
||||
f'[Jira] No active Jira user found for {jira_user_id} in workspace {workspace_id}'
|
||||
)
|
||||
return None, None
|
||||
|
||||
saas_user_auth = await get_user_auth_from_keycloak_id(
|
||||
jira_user.keycloak_user_id
|
||||
)
|
||||
|
||||
if not saas_user_auth:
|
||||
logger.warning(
|
||||
'[Jira] Failed to get OpenHands auth',
|
||||
extra={
|
||||
'keycloak_user_id': jira_user.keycloak_user_id,
|
||||
'user_email': payload.user_email,
|
||||
},
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload,
|
||||
workspace,
|
||||
f'User {payload.user_email} is not authenticated with OpenHands.',
|
||||
)
|
||||
return None, None
|
||||
|
||||
return jira_user, saas_user_auth
|
||||
|
||||
async def start_job(self, view: JiraViewInterface):
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
"""Get repositories that the user has access to."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
)
|
||||
return repos
|
||||
|
||||
async def validate_request(
|
||||
self, request: Request
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict]]:
|
||||
"""Verify Jira webhook signature."""
|
||||
signature_header = request.headers.get('x-hub-signature')
|
||||
signature = signature_header.split('=')[1] if signature_header else None
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
workspace_name = ''
|
||||
|
||||
if payload.get('webhookEvent') == 'comment_created':
|
||||
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
|
||||
elif payload.get('webhookEvent') == 'jira:issue_updated':
|
||||
selfUrl = payload.get('user', {}).get('self')
|
||||
else:
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(selfUrl)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not workspace_name:
|
||||
logger.warning('[Jira] No workspace name found in webhook payload')
|
||||
return False, None, None
|
||||
|
||||
if not signature:
|
||||
logger.warning('[Jira] No signature found in webhook headers')
|
||||
return False, None, None
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if not workspace:
|
||||
logger.warning('[Jira] Could not identify workspace for webhook')
|
||||
return False, None, None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
return False, None, None
|
||||
|
||||
webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
|
||||
digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
if hmac.compare_digest(signature, digest):
|
||||
logger.info('[Jira] Webhook signature verified successfully')
|
||||
return True, signature, payload
|
||||
|
||||
return False, None, None
|
||||
|
||||
def parse_webhook(self, payload: Dict) -> JobContext | None:
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
if event_type == 'comment_created':
|
||||
comment_data = payload.get('comment', {})
|
||||
comment = comment_data.get('body', '')
|
||||
|
||||
if '@openhands' not in comment:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = comment_data.get('author', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
elif event_type == 'jira:issue_updated':
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if 'openhands' not in labels:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = payload.get('user', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
comment = ''
|
||||
else:
|
||||
return None
|
||||
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(base_api_url)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not all(
|
||||
[
|
||||
issue_id,
|
||||
issue_key,
|
||||
user_email,
|
||||
display_name,
|
||||
account_id,
|
||||
workspace_name,
|
||||
base_api_url,
|
||||
]
|
||||
):
|
||||
return None
|
||||
|
||||
return JobContext(
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_msg=comment,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
platform_user_id=account_id,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
|
||||
if not job_context:
|
||||
logger.info('[Jira] Webhook does not match trigger conditions')
|
||||
return
|
||||
|
||||
# Get workspace by user email domain
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
job_context.workspace_name
|
||||
)
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
f'[Jira] No workspace found for email domain: {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Your workspace is not configured with Jira integration.',
|
||||
None,
|
||||
)
|
||||
return
|
||||
|
||||
# Prevent any recursive triggers from the service account
|
||||
if job_context.user_email == workspace.svc_acc_email:
|
||||
return
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Jira integration is not active for your workspace.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Authenticate user
|
||||
jira_user, saas_user_auth = await self.authenticate_user(
|
||||
job_context.platform_user_id, workspace.id
|
||||
)
|
||||
if not jira_user or not saas_user_auth:
|
||||
logger.warning(
|
||||
f'[Jira] User authentication failed for {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
f'User {job_context.user_email} is not authenticated or active in the Jira integration.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Get issue details
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
issue_title, issue_description = await self.get_issue_details(
|
||||
job_context, workspace.jira_cloud_id, workspace.svc_acc_email, api_key
|
||||
)
|
||||
job_context.issue_title = issue_title
|
||||
job_context.issue_description = issue_description
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to get issue context: {str(e)}')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to retrieve issue details. Please check the issue key and try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create Jira view
|
||||
jira_view = await JiraFactory.create_jira_view_from_payload(
|
||||
job_context,
|
||||
saas_user_auth,
|
||||
jira_user,
|
||||
workspace,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to create jira view: {str(e)}', exc_info=True)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, jira_view):
|
||||
return
|
||||
|
||||
await self.start_job(jira_view)
|
||||
|
||||
async def is_job_requested(
|
||||
self, message: Message, jira_view: JiraViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
if isinstance(jira_view, JiraExistingConversationView):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Get user repositories
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
jira_view.saas_user_auth
|
||||
)
|
||||
|
||||
target_str = f'{jira_view.job_context.issue_description}\n{jira_view.job_context.user_msg}'
|
||||
|
||||
# Try to infer repository from issue description
|
||||
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
|
||||
|
||||
if match:
|
||||
# Found exact repository match
|
||||
jira_view.selected_repo = repos[0].full_name
|
||||
logger.info(f'[Jira] Inferred repository: {repos[0].full_name}')
|
||||
return True
|
||||
else:
|
||||
# No clear match - send repository selection comment
|
||||
await self._send_repo_selection_comment(jira_view)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Error in is_job_requested: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, jira_view: JiraViewInterface):
|
||||
"""Start a Jira job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_callback_processor import (
|
||||
@@ -265,79 +345,101 @@ class JiraManager(Manager):
|
||||
)
|
||||
|
||||
try:
|
||||
user_info: JiraUser = jira_view.jira_user
|
||||
logger.info(
|
||||
'[Jira] Starting job',
|
||||
extra={
|
||||
'issue_key': view.payload.issue_key,
|
||||
'user_id': view.jira_user.keycloak_user_id,
|
||||
'selected_repo': view.selected_repo,
|
||||
},
|
||||
f'[Jira] Starting job for user {user_info.keycloak_user_id} '
|
||||
f'issue {jira_view.job_context.issue_key}',
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation_id = await view.create_or_update_conversation(self.jinja_env)
|
||||
conversation_id = await jira_view.create_or_update_conversation(
|
||||
self.jinja_env
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Conversation created',
|
||||
extra={
|
||||
'conversation_id': conversation_id,
|
||||
'issue_key': view.payload.issue_key,
|
||||
},
|
||||
f'[Jira] Created/Updated conversation {conversation_id} for issue {jira_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
# Register callback processor for updates
|
||||
if isinstance(view, JiraNewConversationView):
|
||||
if isinstance(jira_view, JiraNewConversationView):
|
||||
processor = JiraCallbackProcessor(
|
||||
issue_key=view.payload.issue_key,
|
||||
workspace_name=view.jira_workspace.name,
|
||||
)
|
||||
register_callback_processor(conversation_id, processor)
|
||||
logger.info(
|
||||
'[Jira] Callback processor registered',
|
||||
extra={'conversation_id': conversation_id},
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
workspace_name=jira_view.jira_workspace.name,
|
||||
)
|
||||
|
||||
# Send success response
|
||||
msg_info = view.get_response_msg()
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send initial response
|
||||
msg_info = jira_view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(
|
||||
'[Jira] Missing settings error',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
logger.warning(f'[Jira] Missing settings error: {str(e)}')
|
||||
msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(
|
||||
'[Jira] LLM authentication error',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
logger.warning(f'[Jira] LLM authentication error: {str(e)}')
|
||||
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except SessionExpiredError as e:
|
||||
logger.warning(
|
||||
'[Jira] Session expired',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
logger.warning(f'[Jira] Session expired: {str(e)}')
|
||||
msg_info = get_session_expired_message()
|
||||
|
||||
except StartingConvoException as e:
|
||||
logger.warning(
|
||||
'[Jira] Conversation start failed',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
msg_info = str(e)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Unexpected error starting job',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
f'[Jira] Unexpected error starting job: {str(e)}', exc_info=True
|
||||
)
|
||||
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
|
||||
|
||||
# Send response comment
|
||||
await self._send_comment(view, msg_info)
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to send response message: {str(e)}')
|
||||
|
||||
async def get_issue_details(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
) -> Tuple[str, str]:
|
||||
url = f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{job_context.issue_key}'
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.get(url, auth=(svc_acc_email, svc_acc_api_key))
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise ValueError(f'Issue with key {job_context.issue_key} not found.')
|
||||
|
||||
title = issue_payload.get('fields', {}).get('summary', '')
|
||||
description = issue_payload.get('fields', {}).get('description', '')
|
||||
|
||||
if not title:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
if not description:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a description.'
|
||||
)
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
@@ -347,7 +449,6 @@ class JiraManager(Manager):
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
):
|
||||
"""Send a comment to a Jira issue."""
|
||||
url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
@@ -359,53 +460,54 @@ class JiraManager(Manager):
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _send_comment(self, view: JiraViewInterface, msg: str):
|
||||
"""Send a comment using credentials from the view."""
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg),
|
||||
issue_key=view.payload.issue_key,
|
||||
jira_cloud_id=view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Failed to send comment',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
|
||||
async def _send_error_from_payload(
|
||||
async def _send_error_comment(
|
||||
self,
|
||||
payload: JiraWebhookPayload,
|
||||
workspace: JiraWorkspace,
|
||||
job_context: JobContext,
|
||||
error_msg: str,
|
||||
workspace: JiraWorkspace | None,
|
||||
):
|
||||
"""Send error comment before view is created (using payload directly)."""
|
||||
"""Send error comment to Jira issue."""
|
||||
if not workspace:
|
||||
logger.error('[Jira] Cannot send error comment - no workspace available')
|
||||
return
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
issue_key=payload.issue_key,
|
||||
issue_key=job_context.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Failed to send error comment',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
logger.error(f'[Jira] Failed to send error comment: {str(e)}')
|
||||
|
||||
async def _send_repo_selection_comment(self, jira_view: JiraViewInterface):
|
||||
"""Send a comment with repository options for the user to choose."""
|
||||
try:
|
||||
comment_msg = (
|
||||
'I need to know which repository to work with. '
|
||||
'Please add it to your issue description or send a followup comment.'
|
||||
)
|
||||
|
||||
def get_workspace_name_from_payload(self, payload: dict) -> str | None:
|
||||
"""Extract workspace name from Jira webhook payload.
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
This method is used by the route for signature verification.
|
||||
"""
|
||||
parse_result = self.payload_parser.parse(payload)
|
||||
if isinstance(parse_result, JiraPayloadSuccess):
|
||||
return parse_result.payload.workspace_name
|
||||
return None
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Sent repository selection comment for issue {jira_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to send repository selection comment: {str(e)}'
|
||||
)
|
||||
|
||||
@@ -1,267 +0,0 @@
|
||||
"""Centralized payload parsing for Jira webhooks.
|
||||
|
||||
This module provides a single source of truth for parsing and validating
|
||||
Jira webhook payloads, replacing scattered parsing logic throughout the codebase.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class JiraEventType(Enum):
|
||||
"""Types of Jira events we handle."""
|
||||
|
||||
LABELED_TICKET = 'labeled_ticket'
|
||||
COMMENT_MENTION = 'comment_mention'
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraWebhookPayload:
|
||||
"""Normalized, validated representation of a Jira webhook payload.
|
||||
|
||||
This immutable dataclass replaces JobContext and provides a single
|
||||
source of truth for all webhook data. All parsing happens in
|
||||
JiraPayloadParser, ensuring consistent validation.
|
||||
"""
|
||||
|
||||
event_type: JiraEventType
|
||||
raw_event: str # Original webhookEvent value
|
||||
|
||||
# Issue data
|
||||
issue_id: str
|
||||
issue_key: str
|
||||
|
||||
# User data
|
||||
user_email: str
|
||||
display_name: str
|
||||
account_id: str
|
||||
|
||||
# Workspace data (derived from issue self URL)
|
||||
workspace_name: str
|
||||
base_api_url: str
|
||||
|
||||
# Event-specific data
|
||||
comment_body: str = '' # For comment events
|
||||
|
||||
@property
|
||||
def user_msg(self) -> str:
|
||||
"""Alias for comment_body for backward compatibility."""
|
||||
return self.comment_body
|
||||
|
||||
|
||||
class JiraPayloadParseError(Exception):
|
||||
"""Raised when payload parsing fails."""
|
||||
|
||||
def __init__(self, reason: str, event_type: str | None = None):
|
||||
self.reason = reason
|
||||
self.event_type = event_type
|
||||
super().__init__(reason)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraPayloadSuccess:
|
||||
"""Result when parsing succeeds."""
|
||||
|
||||
payload: JiraWebhookPayload
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraPayloadSkipped:
|
||||
"""Result when event is intentionally skipped."""
|
||||
|
||||
skip_reason: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraPayloadError:
|
||||
"""Result when parsing fails due to invalid data."""
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
JiraPayloadParseResult = JiraPayloadSuccess | JiraPayloadSkipped | JiraPayloadError
|
||||
|
||||
|
||||
class JiraPayloadParser:
|
||||
"""Centralized parser for Jira webhook payloads.
|
||||
|
||||
This class provides a single entry point for parsing webhooks,
|
||||
determining event types, and extracting all necessary fields.
|
||||
Replaces scattered parsing in JiraFactory and JiraManager.
|
||||
"""
|
||||
|
||||
def __init__(self, oh_label: str, inline_oh_label: str):
|
||||
"""Initialize parser with OpenHands label configuration.
|
||||
|
||||
Args:
|
||||
oh_label: Label that triggers OpenHands (e.g., 'openhands')
|
||||
inline_oh_label: Mention that triggers OpenHands (e.g., '@openhands')
|
||||
"""
|
||||
self.oh_label = oh_label
|
||||
self.inline_oh_label = inline_oh_label
|
||||
|
||||
def parse(self, raw_payload: dict) -> JiraPayloadParseResult:
|
||||
"""Parse a raw webhook payload into a normalized JiraWebhookPayload.
|
||||
|
||||
Args:
|
||||
raw_payload: The raw webhook payload dict from Jira
|
||||
|
||||
Returns:
|
||||
One of:
|
||||
- JiraPayloadSuccess: Valid, actionable event with payload
|
||||
- JiraPayloadSkipped: Event we intentionally don't process
|
||||
- JiraPayloadError: Malformed payload we expected to process
|
||||
"""
|
||||
webhook_event = raw_payload.get('webhookEvent', '')
|
||||
|
||||
logger.debug(
|
||||
'[Jira] Parsing webhook payload', extra={'webhook_event': webhook_event}
|
||||
)
|
||||
|
||||
if webhook_event == 'jira:issue_updated':
|
||||
return self._parse_label_event(raw_payload, webhook_event)
|
||||
elif webhook_event == 'comment_created':
|
||||
return self._parse_comment_event(raw_payload, webhook_event)
|
||||
else:
|
||||
return JiraPayloadSkipped(f'Unhandled webhook event type: {webhook_event}')
|
||||
|
||||
def _parse_label_event(
|
||||
self, payload: dict, webhook_event: str
|
||||
) -> JiraPayloadParseResult:
|
||||
"""Parse an issue_updated event for label changes."""
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
|
||||
# Extract labels that were added
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if self.oh_label not in labels:
|
||||
return JiraPayloadSkipped(
|
||||
f"Label event does not contain '{self.oh_label}' label"
|
||||
)
|
||||
|
||||
# For label events, user data comes from 'user' field
|
||||
user_data = payload.get('user', {})
|
||||
return self._extract_and_validate(
|
||||
payload=payload,
|
||||
user_data=user_data,
|
||||
event_type=JiraEventType.LABELED_TICKET,
|
||||
webhook_event=webhook_event,
|
||||
comment_body='',
|
||||
)
|
||||
|
||||
def _parse_comment_event(
|
||||
self, payload: dict, webhook_event: str
|
||||
) -> JiraPayloadParseResult:
|
||||
"""Parse a comment_created event."""
|
||||
comment_data = payload.get('comment', {})
|
||||
comment_body = comment_data.get('body', '')
|
||||
|
||||
if not self._has_mention(comment_body):
|
||||
return JiraPayloadSkipped(
|
||||
f"Comment does not mention '{self.inline_oh_label}'"
|
||||
)
|
||||
|
||||
# For comment events, user data comes from 'comment.author'
|
||||
user_data = comment_data.get('author', {})
|
||||
return self._extract_and_validate(
|
||||
payload=payload,
|
||||
user_data=user_data,
|
||||
event_type=JiraEventType.COMMENT_MENTION,
|
||||
webhook_event=webhook_event,
|
||||
comment_body=comment_body,
|
||||
)
|
||||
|
||||
def _has_mention(self, text: str) -> bool:
|
||||
"""Check if text contains an exact mention of OpenHands."""
|
||||
from integrations.utils import has_exact_mention
|
||||
|
||||
return has_exact_mention(text, self.inline_oh_label)
|
||||
|
||||
def _extract_and_validate(
|
||||
self,
|
||||
payload: dict,
|
||||
user_data: dict,
|
||||
event_type: JiraEventType,
|
||||
webhook_event: str,
|
||||
comment_body: str,
|
||||
) -> JiraPayloadParseResult:
|
||||
"""Extract common fields and validate required data is present."""
|
||||
issue_data = payload.get('issue', {})
|
||||
|
||||
# Extract all fields with empty string defaults (makes them str type)
|
||||
issue_id = issue_data.get('id', '')
|
||||
issue_key = issue_data.get('key', '')
|
||||
user_email = user_data.get('emailAddress', '')
|
||||
display_name = user_data.get('displayName', '')
|
||||
account_id = user_data.get('accountId', '')
|
||||
base_api_url, workspace_name = self._extract_workspace_from_url(
|
||||
issue_data.get('self', '')
|
||||
)
|
||||
|
||||
# Validate required fields
|
||||
missing: list[str] = []
|
||||
if not issue_id:
|
||||
missing.append('issue.id')
|
||||
if not issue_key:
|
||||
missing.append('issue.key')
|
||||
if not user_email:
|
||||
missing.append('user.emailAddress')
|
||||
if not display_name:
|
||||
missing.append('user.displayName')
|
||||
if not account_id:
|
||||
missing.append('user.accountId')
|
||||
if not workspace_name:
|
||||
missing.append('workspace_name (derived from issue.self)')
|
||||
if not base_api_url:
|
||||
missing.append('base_api_url (derived from issue.self)')
|
||||
|
||||
if missing:
|
||||
return JiraPayloadError(f"Missing required fields: {', '.join(missing)}")
|
||||
|
||||
return JiraPayloadSuccess(
|
||||
JiraWebhookPayload(
|
||||
event_type=event_type,
|
||||
raw_event=webhook_event,
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
account_id=account_id,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
comment_body=comment_body,
|
||||
)
|
||||
)
|
||||
|
||||
def _extract_workspace_from_url(self, self_url: str) -> tuple[str, str]:
|
||||
"""Extract base API URL and workspace name from issue self URL.
|
||||
|
||||
Args:
|
||||
self_url: The 'self' URL from the issue data
|
||||
|
||||
Returns:
|
||||
Tuple of (base_api_url, workspace_name)
|
||||
"""
|
||||
if not self_url:
|
||||
return '', ''
|
||||
|
||||
# Extract base URL (everything before /rest/)
|
||||
if '/rest/' in self_url:
|
||||
base_api_url = self_url.split('/rest/')[0]
|
||||
else:
|
||||
parsed = urlparse(self_url)
|
||||
base_api_url = f'{parsed.scheme}://{parsed.netloc}'
|
||||
|
||||
# Extract workspace name (hostname)
|
||||
parsed = urlparse(base_api_url)
|
||||
workspace_name = parsed.hostname or ''
|
||||
|
||||
return base_api_url, workspace_name
|
||||
@@ -1,42 +1,26 @@
|
||||
"""Type definitions and interfaces for Jira integration."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import Environment
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from integrations.jira.jira_payload import JiraWebhookPayload
|
||||
|
||||
|
||||
class JiraViewInterface(ABC):
|
||||
"""Interface for Jira views that handle different types of Jira interactions.
|
||||
"""Interface for Jira views that handle different types of Jira interactions."""
|
||||
|
||||
Views hold the webhook payload directly rather than duplicating fields,
|
||||
and fetch issue details lazily when needed.
|
||||
"""
|
||||
|
||||
# Core data - view holds these references
|
||||
payload: 'JiraWebhookPayload'
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
|
||||
# Mutable state set during processing
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
async def get_issue_details(self) -> tuple[str, str]:
|
||||
"""Fetch and cache issue title and description from Jira API.
|
||||
|
||||
Returns:
|
||||
Tuple of (issue_title, issue_description)
|
||||
"""
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -51,21 +35,6 @@ class JiraViewInterface(ABC):
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""Exception raised when starting a conversation fails.
|
||||
|
||||
This provides user-friendly error messages that can be sent back to Jira.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RepositoryNotFoundError(Exception):
|
||||
"""Raised when a repository cannot be determined from the issue.
|
||||
|
||||
This is a separate error domain from StartingConvoException - it represents
|
||||
a precondition failure (no repo configured/found) rather than a conversation
|
||||
creation failure. The manager catches this and converts it to a user-friendly
|
||||
message.
|
||||
"""
|
||||
"""Exception raised when starting a conversation fails."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,21 +1,8 @@
|
||||
"""Jira view implementations and factory.
|
||||
from dataclasses import dataclass
|
||||
|
||||
Views are responsible for:
|
||||
- Holding the webhook payload and auth context
|
||||
- Lazy-loading issue details from Jira API when needed
|
||||
- Creating conversations with the selected repository
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
from integrations.jira.jira_payload import JiraWebhookPayload
|
||||
from integrations.jira.jira_types import (
|
||||
JiraViewInterface,
|
||||
RepositoryNotFoundError,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.utils import CONVERSATION_URL, infer_repo_from_message
|
||||
from integrations.jira.jira_types import JiraViewInterface, StartingConvoException
|
||||
from integrations.models import JobContext
|
||||
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
|
||||
from jinja2 import Environment
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
@@ -23,147 +10,55 @@ from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
integration_store = JiraIntegrationStore.get_instance()
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraNewConversationView(JiraViewInterface):
|
||||
"""View for creating a new Jira conversation.
|
||||
|
||||
This view holds the webhook payload directly and lazily fetches
|
||||
issue details when needed for rendering templates.
|
||||
"""
|
||||
|
||||
payload: JiraWebhookPayload
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None = None
|
||||
conversation_id: str = ''
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
# Lazy-loaded issue details (cached after first fetch)
|
||||
_issue_title: str | None = field(default=None, repr=False)
|
||||
_issue_description: str | None = field(default=None, repr=False)
|
||||
|
||||
# Decrypted API key (set by factory)
|
||||
_decrypted_api_key: str = field(default='', repr=False)
|
||||
|
||||
async def get_issue_details(self) -> tuple[str, str]:
|
||||
"""Fetch issue details from Jira API (cached after first call).
|
||||
|
||||
Returns:
|
||||
Tuple of (issue_title, issue_description)
|
||||
|
||||
Raises:
|
||||
StartingConvoException: If issue details cannot be fetched
|
||||
"""
|
||||
if self._issue_title is not None and self._issue_description is not None:
|
||||
return self._issue_title, self._issue_description
|
||||
|
||||
try:
|
||||
url = f'{JIRA_CLOUD_API_URL}/{self.jira_workspace.jira_cloud_id}/rest/api/2/issue/{self.payload.issue_key}'
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
auth=(
|
||||
self.jira_workspace.svc_acc_email,
|
||||
self._decrypted_api_key,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise StartingConvoException(
|
||||
f'Issue {self.payload.issue_key} not found.'
|
||||
)
|
||||
|
||||
self._issue_title = issue_payload.get('fields', {}).get('summary', '')
|
||||
self._issue_description = (
|
||||
issue_payload.get('fields', {}).get('description', '') or ''
|
||||
)
|
||||
|
||||
if not self._issue_title:
|
||||
raise StartingConvoException(
|
||||
f'Issue {self.payload.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Fetched issue details',
|
||||
extra={
|
||||
'issue_key': self.payload.issue_key,
|
||||
'has_description': bool(self._issue_description),
|
||||
},
|
||||
)
|
||||
|
||||
return self._issue_title, self._issue_description
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
'[Jira] Failed to fetch issue details',
|
||||
extra={
|
||||
'issue_key': self.payload.issue_key,
|
||||
'status': e.response.status_code,
|
||||
},
|
||||
)
|
||||
raise StartingConvoException(
|
||||
f'Failed to fetch issue details: HTTP {e.response.status_code}'
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, StartingConvoException):
|
||||
raise
|
||||
logger.error(
|
||||
'[Jira] Failed to fetch issue details',
|
||||
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get instructions for the conversation.
|
||||
|
||||
This fetches issue details if not already cached.
|
||||
|
||||
Returns:
|
||||
Tuple of (system_instructions, user_message)
|
||||
"""
|
||||
issue_title, issue_description = await self.get_issue_details()
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_new_conversation.j2')
|
||||
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.payload.issue_key,
|
||||
issue_title=issue_title,
|
||||
issue_description=issue_description,
|
||||
user_message=self.payload.user_msg,
|
||||
issue_key=self.job_context.issue_key,
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
)
|
||||
|
||||
return instructions, user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Jira conversation.
|
||||
"""Create a new Jira conversation"""
|
||||
|
||||
Returns:
|
||||
The conversation ID
|
||||
|
||||
Raises:
|
||||
StartingConvoException: If conversation creation fails
|
||||
"""
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -181,259 +76,149 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
|
||||
logger.info(
|
||||
'[Jira] Created conversation',
|
||||
extra={
|
||||
'conversation_id': self.conversation_id,
|
||||
'issue_key': self.payload.issue_key,
|
||||
'selected_repo': self.selected_repo,
|
||||
},
|
||||
)
|
||||
logger.info(f'[Jira] Created conversation {self.conversation_id}')
|
||||
|
||||
# Store Jira conversation mapping
|
||||
jira_conversation = JiraConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
issue_id=self.payload.issue_id,
|
||||
issue_key=self.payload.issue_key,
|
||||
issue_id=self.job_context.issue_id,
|
||||
issue_key=self.job_context.issue_key,
|
||||
jira_user_id=self.jira_user.id,
|
||||
)
|
||||
|
||||
await integration_store.create_conversation(jira_conversation)
|
||||
|
||||
return self.conversation_id
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, StartingConvoException):
|
||||
raise
|
||||
logger.error(
|
||||
'[Jira] Failed to create conversation',
|
||||
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira."""
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.payload.display_name} can [track my progress here|{conversation_link}]."
|
||||
return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraExistingConversationView(JiraViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
)
|
||||
|
||||
return '', user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Jira conversation"""
|
||||
|
||||
user_id = self.jira_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
|
||||
try:
|
||||
await conversation_store.get_metadata(self.conversation_id)
|
||||
except FileNotFoundError:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
# Should we raise here if there are no providers?
|
||||
providers_set = list(provider_tokens.keys()) if provider_tokens else []
|
||||
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, self.conversation_id, providers_set
|
||||
)
|
||||
|
||||
# Either join ongoing conversation, or restart the conversation
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
agent_state = (
|
||||
None
|
||||
if len(final_agent_observation) == 0
|
||||
else final_agent_observation[0].agent_state
|
||||
)
|
||||
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
class JiraFactory:
|
||||
"""Factory for creating Jira views.
|
||||
|
||||
The factory is responsible for:
|
||||
- Creating the appropriate view type
|
||||
- Inferring and selecting the repository
|
||||
- Validating all required data is available
|
||||
|
||||
Repository selection happens here so that view creation either
|
||||
succeeds with a valid repo or fails with a clear error.
|
||||
"""
|
||||
"""Factory for creating Jira views based on message content"""
|
||||
|
||||
@staticmethod
|
||||
async def _create_provider_handler(user_auth: UserAuth) -> ProviderHandler | None:
|
||||
"""Create a ProviderHandler for the user."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return None
|
||||
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
|
||||
return ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_potential_repos(
|
||||
issue_key: str,
|
||||
issue_title: str,
|
||||
issue_description: str,
|
||||
user_msg: str,
|
||||
) -> list[str]:
|
||||
"""Extract potential repository names from issue content.
|
||||
|
||||
Raises:
|
||||
RepositoryNotFoundError: If no potential repos found in text.
|
||||
"""
|
||||
search_text = f'{issue_title}\n{issue_description}\n{user_msg}'
|
||||
potential_repos = infer_repo_from_message(search_text)
|
||||
|
||||
if not potential_repos:
|
||||
raise RepositoryNotFoundError(
|
||||
'Could not determine which repository to use. '
|
||||
'Please mention the repository (e.g., owner/repo) in the issue description or comment.'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Found potential repositories in issue content',
|
||||
extra={'issue_key': issue_key, 'potential_repos': potential_repos},
|
||||
)
|
||||
return potential_repos
|
||||
|
||||
@staticmethod
|
||||
async def _verify_repos(
|
||||
issue_key: str,
|
||||
potential_repos: list[str],
|
||||
provider_handler: ProviderHandler,
|
||||
) -> list[str]:
|
||||
"""Verify which repos the user has access to."""
|
||||
verified_repos: list[str] = []
|
||||
|
||||
for repo_name in potential_repos:
|
||||
try:
|
||||
repository = await provider_handler.verify_repo_provider(repo_name)
|
||||
verified_repos.append(repository.full_name)
|
||||
logger.debug(
|
||||
'[Jira] Repository verification succeeded',
|
||||
extra={'issue_key': issue_key, 'repository': repository.full_name},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
'[Jira] Repository verification failed',
|
||||
extra={
|
||||
'issue_key': issue_key,
|
||||
'repo_name': repo_name,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
return verified_repos
|
||||
|
||||
@staticmethod
|
||||
def _select_single_repo(
|
||||
issue_key: str,
|
||||
potential_repos: list[str],
|
||||
verified_repos: list[str],
|
||||
) -> str:
|
||||
"""Select exactly one repo from verified repos.
|
||||
|
||||
Raises:
|
||||
RepositoryNotFoundError: If zero or multiple repos verified.
|
||||
"""
|
||||
if len(verified_repos) == 0:
|
||||
raise RepositoryNotFoundError(
|
||||
f'Could not access any of the mentioned repositories: {", ".join(potential_repos)}. '
|
||||
'Please ensure you have access to the repository and it exists.'
|
||||
)
|
||||
|
||||
if len(verified_repos) > 1:
|
||||
raise RepositoryNotFoundError(
|
||||
f'Multiple repositories found: {", ".join(verified_repos)}. '
|
||||
'Please specify exactly one repository in the issue description or comment.'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Verified repository access',
|
||||
extra={'issue_key': issue_key, 'repository': verified_repos[0]},
|
||||
)
|
||||
return verified_repos[0]
|
||||
|
||||
@staticmethod
|
||||
async def _infer_repository(
|
||||
payload: JiraWebhookPayload,
|
||||
user_auth: UserAuth,
|
||||
issue_title: str,
|
||||
issue_description: str,
|
||||
) -> str:
|
||||
"""Infer and verify the repository from issue content.
|
||||
|
||||
Raises:
|
||||
RepositoryNotFoundError: If no valid repository can be determined.
|
||||
"""
|
||||
provider_handler = await JiraFactory._create_provider_handler(user_auth)
|
||||
if not provider_handler:
|
||||
raise RepositoryNotFoundError(
|
||||
'No Git provider connected. Please connect a Git provider in OpenHands settings.'
|
||||
)
|
||||
|
||||
potential_repos = JiraFactory._extract_potential_repos(
|
||||
payload.issue_key, issue_title, issue_description, payload.user_msg
|
||||
)
|
||||
|
||||
verified_repos = await JiraFactory._verify_repos(
|
||||
payload.issue_key, potential_repos, provider_handler
|
||||
)
|
||||
|
||||
return JiraFactory._select_single_repo(
|
||||
payload.issue_key, potential_repos, verified_repos
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_view(
|
||||
payload: JiraWebhookPayload,
|
||||
workspace: JiraWorkspace,
|
||||
user: JiraUser,
|
||||
user_auth: UserAuth,
|
||||
decrypted_api_key: str,
|
||||
async def create_jira_view_from_payload(
|
||||
job_context: JobContext,
|
||||
saas_user_auth: UserAuth,
|
||||
jira_user: JiraUser,
|
||||
jira_workspace: JiraWorkspace,
|
||||
) -> JiraViewInterface:
|
||||
"""Create a Jira view with repository already selected.
|
||||
"""Create appropriate Jira view based on the message and user state"""
|
||||
|
||||
This factory method:
|
||||
1. Creates the view with payload and auth context
|
||||
2. Fetches issue details (needed for repo inference)
|
||||
3. Infers and selects the repository
|
||||
if not jira_user or not saas_user_auth or not jira_workspace:
|
||||
raise StartingConvoException('User not authenticated with Jira integration')
|
||||
|
||||
If any step fails, an appropriate exception is raised with
|
||||
a user-friendly message.
|
||||
|
||||
Args:
|
||||
payload: Parsed webhook payload
|
||||
workspace: The Jira workspace
|
||||
user: The Jira user
|
||||
user_auth: OpenHands user authentication
|
||||
decrypted_api_key: Decrypted service account API key
|
||||
|
||||
Returns:
|
||||
A JiraViewInterface with selected_repo populated
|
||||
|
||||
Raises:
|
||||
StartingConvoException: If view creation fails
|
||||
RepositoryNotFoundError: If repository cannot be determined
|
||||
"""
|
||||
logger.info(
|
||||
'[Jira] Creating view',
|
||||
extra={
|
||||
'issue_key': payload.issue_key,
|
||||
'event_type': payload.event_type.value,
|
||||
},
|
||||
conversation = await integration_store.get_user_conversations_by_issue_id(
|
||||
job_context.issue_id, jira_user.id
|
||||
)
|
||||
|
||||
# Create the view
|
||||
view = JiraNewConversationView(
|
||||
payload=payload,
|
||||
saas_user_auth=user_auth,
|
||||
jira_user=user,
|
||||
jira_workspace=workspace,
|
||||
_decrypted_api_key=decrypted_api_key,
|
||||
if conversation:
|
||||
logger.info(
|
||||
f'[Jira] Found existing conversation for issue {job_context.issue_id}'
|
||||
)
|
||||
return JiraExistingConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None,
|
||||
conversation_id=conversation.conversation_id,
|
||||
)
|
||||
|
||||
return JiraNewConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None, # Will be set later after repo inference
|
||||
conversation_id='', # Will be set when conversation is created
|
||||
)
|
||||
|
||||
# Fetch issue details (needed for repo inference)
|
||||
try:
|
||||
issue_title, issue_description = await view.get_issue_details()
|
||||
except StartingConvoException:
|
||||
raise # Re-raise with original message
|
||||
except Exception as e:
|
||||
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
|
||||
|
||||
# Infer and select repository
|
||||
selected_repo = await JiraFactory._infer_repository(
|
||||
payload=payload,
|
||||
user_auth=user_auth,
|
||||
issue_title=issue_title,
|
||||
issue_description=issue_description,
|
||||
)
|
||||
|
||||
view.selected_repo = selected_repo
|
||||
|
||||
logger.info(
|
||||
'[Jira] View created successfully',
|
||||
extra={
|
||||
'issue_key': payload.issue_key,
|
||||
'selected_repo': selected_repo,
|
||||
},
|
||||
)
|
||||
|
||||
return view
|
||||
|
||||
@@ -16,6 +16,11 @@ class Manager(ABC):
|
||||
"Send message to integration from Openhands server"
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
"Confirm that a job is being requested"
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self):
|
||||
"Kick off a job with openhands agent"
|
||||
|
||||
@@ -37,12 +37,11 @@ class ResolverUserContext(UserContext):
|
||||
return f'https://github.com/{repository}.git'
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token string from git_provider_tokens
|
||||
# Return the appropriate token from git_provider_tokens
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens:
|
||||
provider_token = provider_tokens.get(provider_type)
|
||||
if provider_token and provider_token.token:
|
||||
return provider_token.token.get_secret_value()
|
||||
return provider_tokens.get(provider_type)
|
||||
return None
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
|
||||
@@ -189,7 +189,6 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
'channel_id': self.channel_id,
|
||||
'conversation_id': self.conversation_id,
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'org_id': user_info.org_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
'v1_enabled': v1_enabled,
|
||||
},
|
||||
@@ -198,7 +197,6 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
conversation_id=self.conversation_id,
|
||||
channel_id=self.channel_id,
|
||||
keycloak_user_id=user_info.keycloak_user_id,
|
||||
org_id=user_info.org_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
|
||||
v1_enabled=v1_enabled,
|
||||
@@ -401,10 +399,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
instructions, _ = self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg_action = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
self.conversation_id, event_to_dict(user_msg_action)
|
||||
)
|
||||
|
||||
async def send_message_to_v1_conversation(self, jinja: Environment):
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import Repository
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
@@ -1,24 +1,19 @@
|
||||
from uuid import UUID
|
||||
|
||||
import stripe
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer:
|
||||
@@ -26,76 +21,46 @@ async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
|
||||
# If that fails, fallback to stripe
|
||||
search_result = await stripe.Customer.search_async(
|
||||
query=f"metadata['org_id']:'{str(org_id)}'",
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
)
|
||||
data = search_result.data
|
||||
if not data:
|
||||
logger.info(
|
||||
'no_customer_for_org_id',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
|
||||
return None
|
||||
return data[0].id # type: ignore [attr-defined]
|
||||
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
return customer_id
|
||||
|
||||
|
||||
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
# Get the current org for the user
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
async def find_or_create_customer(user_id: str) -> str:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id:
|
||||
return {'customer_id': customer_id, 'org_id': str(org.id)}
|
||||
logger.info(
|
||||
'creating_customer',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
return customer_id
|
||||
logger.info('creating_customer', extra={'user_id': user_id})
|
||||
|
||||
# Get the user info from keycloak
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
|
||||
|
||||
# Create the customer in stripe
|
||||
customer = await stripe.Customer.create_async(
|
||||
email=org.contact_email,
|
||||
metadata={'org_id': str(org.id)},
|
||||
email=str(user_info.get('email', '')),
|
||||
metadata={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
org_id=org.id,
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
|
||||
)
|
||||
return {'customer_id': customer.id, 'org_id': str(org.id)}
|
||||
return customer.id
|
||||
|
||||
|
||||
async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
async def has_payment_method(user_id: str) -> bool:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id is None:
|
||||
return False
|
||||
@@ -106,28 +71,3 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
|
||||
)
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -6,9 +6,15 @@ import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.config import get_config
|
||||
from server.constants import WEB_HOST
|
||||
from storage.org_store import OrgStore
|
||||
from storage.database import session_maker
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events import Event, EventSource
|
||||
@@ -119,17 +125,26 @@ async def get_user_v1_enabled_setting(user_id: str | None) -> bool:
|
||||
Returns:
|
||||
True if V1 conversations are enabled for this user, False otherwise
|
||||
"""
|
||||
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
if not org or org.v1_enabled is None:
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.v1_enabled is None:
|
||||
return False
|
||||
|
||||
return org.v1_enabled
|
||||
return settings.v1_enabled
|
||||
|
||||
|
||||
def has_exact_mention(text: str, mention: str) -> bool:
|
||||
@@ -394,46 +409,102 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
return message + footer
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
|
||||
|
||||
def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
"""
|
||||
Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
and direct mentions in text. Supports GitHub, GitLab, and BitBucket.
|
||||
Args:
|
||||
user_msg: Input message that may contain repository references
|
||||
Returns:
|
||||
List of repository names in 'owner/repo' format, empty list if none found
|
||||
"""
|
||||
# Normalize the message by removing extra whitespace and newlines
|
||||
normalized_msg = re.sub(r'\s+', ' ', user_msg.strip())
|
||||
|
||||
git_url_pattern = (
|
||||
r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/'
|
||||
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?'
|
||||
r'(?:[/?#].*?)?(?=\s|$|[^\w.-])'
|
||||
)
|
||||
# Pattern to match Git URLs from GitHub, GitLab, and BitBucket
|
||||
# Captures: protocol, domain, owner, repo (with optional .git extension)
|
||||
git_url_pattern = r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?(?:[/?#].*?)?(?=\s|$|[^\w.-])'
|
||||
|
||||
# UPDATED: allow {{ owner/repo }} in addition to existing boundaries
|
||||
# Pattern to match direct owner/repo mentions (e.g., "OpenHands/OpenHands")
|
||||
# Must be surrounded by word boundaries or specific characters to avoid false positives
|
||||
direct_pattern = (
|
||||
r'(?:^|\s|{{|[\[\(\'":`])' # left boundary
|
||||
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)'
|
||||
r'(?=\s|$|}}|[\]\)\'",.:`])' # right boundary
|
||||
r'(?:^|\s|[\[\(\'"])([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)(?=\s|$|[\]\)\'",.])'
|
||||
)
|
||||
|
||||
matches: list[str] = []
|
||||
matches = []
|
||||
|
||||
# Git URLs first (highest priority)
|
||||
for owner, repo in re.findall(git_url_pattern, normalized_msg):
|
||||
# First, find all Git URLs (highest priority)
|
||||
git_matches = re.findall(git_url_pattern, normalized_msg)
|
||||
for owner, repo in git_matches:
|
||||
# Remove .git extension if present
|
||||
repo = re.sub(r'\.git$', '', repo)
|
||||
matches.append(f'{owner}/{repo}')
|
||||
|
||||
# Direct mentions
|
||||
for owner, repo in re.findall(direct_pattern, normalized_msg):
|
||||
# Second, find all direct owner/repo mentions
|
||||
direct_matches = re.findall(direct_pattern, normalized_msg)
|
||||
for owner, repo in direct_matches:
|
||||
full_match = f'{owner}/{repo}'
|
||||
|
||||
# Skip if it looks like a version number, date, or file path
|
||||
if (
|
||||
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match)
|
||||
or re.match(r'^\d{1,2}/\d{1,2}$', full_match)
|
||||
or re.match(r'^[A-Z]/[A-Z]$', full_match)
|
||||
or repo.endswith(('.txt', '.md', '.py', '.js'))
|
||||
or ('.' in repo and len(repo.split('.')) > 2)
|
||||
):
|
||||
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match) # version numbers
|
||||
or re.match(r'^\d{1,2}/\d{1,2}$', full_match) # dates
|
||||
or re.match(r'^[A-Z]/[A-Z]$', full_match) # single letters
|
||||
or repo.endswith('.txt')
|
||||
or repo.endswith('.md') # file extensions
|
||||
or repo.endswith('.py')
|
||||
or repo.endswith('.js')
|
||||
or '.' in repo
|
||||
and len(repo.split('.')) > 2
|
||||
): # complex file paths
|
||||
continue
|
||||
|
||||
# Avoid duplicates from Git URLs already found
|
||||
if full_match not in matches:
|
||||
matches.append(full_match)
|
||||
|
||||
|
||||
@@ -20,8 +20,6 @@ down_revision = '059'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# TODO: decide whether to modify this for orgs or users
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
@@ -30,10 +28,8 @@ def upgrade():
|
||||
|
||||
This replaces the functionality of the removed admin maintenance endpoint.
|
||||
"""
|
||||
|
||||
# Hardcoded value to prevent migration failures when constant is removed from codebase
|
||||
# This migration has already run in production, so we use the value that was current at the time
|
||||
CURRENT_USER_SETTINGS_VERSION = 4
|
||||
# Import here to avoid circular imports
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
|
||||
# Create a connection and bind it to a session
|
||||
connection = op.get_bind()
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
"""create org tables from pgerd schema
|
||||
|
||||
Revision ID: 089
|
||||
Revises: 088
|
||||
Create Date: 2025-01-07 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '089'
|
||||
down_revision: Union[str, None] = '088'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove current settings table
|
||||
op.execute('DROP TABLE IF EXISTS settings')
|
||||
|
||||
# Add already_migrated column to user_settings table
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'already_migrated',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Create role table
|
||||
op.create_table(
|
||||
'role',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('rank', sa.Integer, nullable=False),
|
||||
sa.UniqueConstraint('name', name='role_name_unique'),
|
||||
)
|
||||
|
||||
# 1. Create default roles
|
||||
op.execute(
|
||||
sa.text("""
|
||||
INSERT INTO role (name, rank) VALUES ('owner', 10), ('admin', 20), ('user', 1000)
|
||||
ON CONFLICT (name) DO NOTHING;
|
||||
""")
|
||||
)
|
||||
|
||||
# Create org table with settings fields
|
||||
op.create_table(
|
||||
'org',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('contact_name', sa.String, nullable=True),
|
||||
sa.Column('contact_email', sa.String, nullable=True),
|
||||
sa.Column('conversation_expiration', sa.Integer, nullable=True),
|
||||
# Settings fields moved to org table
|
||||
sa.Column('agent', sa.String, nullable=True),
|
||||
sa.Column('default_max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('security_analyzer', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'confirmation_mode',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('default_llm_model', sa.String, nullable=True),
|
||||
sa.Column('default_llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer, nullable=True),
|
||||
sa.Column(
|
||||
'enable_default_condenser',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('billing_margin', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_proactive_conversation_starters',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('sandbox_base_container_image', sa.String, nullable=True),
|
||||
sa.Column('sandbox_runtime_container_image', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'org_version', sa.Integer, nullable=False, server_default=sa.text('0')
|
||||
),
|
||||
sa.Column('mcp_config', sa.JSON, nullable=True),
|
||||
sa.Column('_search_api_key', sa.String, nullable=True),
|
||||
sa.Column('_sandbox_api_key', sa.String, nullable=True),
|
||||
sa.Column('max_budget_per_task', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_solvability_analysis',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('v1_enabled', sa.Boolean, nullable=True),
|
||||
sa.Column('condenser_max_size', sa.Integer, nullable=True),
|
||||
sa.UniqueConstraint('name', name='org_name_unique'),
|
||||
)
|
||||
|
||||
# Create user table with user-specific settings fields
|
||||
op.create_table(
|
||||
'user',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('current_org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=True),
|
||||
sa.Column('accepted_tos', sa.DateTime, nullable=True),
|
||||
sa.Column(
|
||||
'enable_sound_notifications',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('language', sa.String, nullable=True),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean, nullable=True),
|
||||
sa.Column('email', sa.String, nullable=True),
|
||||
sa.Column('email_verified', sa.Boolean, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
['current_org_id'], ['org.id'], name='current_org_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='user_role_fkey'),
|
||||
)
|
||||
|
||||
# Create org_member table (junction table for many-to-many relationship)
|
||||
op.create_table(
|
||||
'org_member',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('_llm_api_key', sa.String, nullable=False),
|
||||
sa.Column('max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('llm_model', sa.String, nullable=True),
|
||||
sa.Column('_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('status', sa.String, nullable=True),
|
||||
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='om_org_fkey'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='om_user_fkey'),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='om_role_fkey'),
|
||||
sa.PrimaryKeyConstraint('org_id', 'user_id'),
|
||||
)
|
||||
|
||||
# Add org_id column to existing tables
|
||||
# billing_sessions
|
||||
op.add_column(
|
||||
'billing_sessions',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# Create conversation_metadata_saas table
|
||||
op.create_table(
|
||||
'conversation_metadata_saas',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'
|
||||
),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
# custom_secrets
|
||||
op.add_column(
|
||||
'custom_secrets',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'custom_secrets_org_fkey', 'custom_secrets', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# api_keys
|
||||
op.add_column(
|
||||
'api_keys', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key('api_keys_org_fkey', 'api_keys', 'org', ['org_id'], ['id'])
|
||||
|
||||
# slack_conversation
|
||||
op.add_column(
|
||||
'slack_conversation',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# slack_users
|
||||
op.add_column(
|
||||
'slack_users', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_users_org_fkey', 'slack_users', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# stripe_customers
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
op.add_column(
|
||||
'stripe_customers',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop already_migrated column from user_settings table
|
||||
op.drop_column('user_settings', 'already_migrated')
|
||||
|
||||
# Drop foreign keys and columns added to existing tables
|
||||
op.drop_constraint(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('stripe_customers', 'org_id')
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.drop_constraint('slack_users_org_fkey', 'slack_users', type_='foreignkey')
|
||||
op.drop_column('slack_users', 'org_id')
|
||||
|
||||
op.drop_constraint(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('slack_conversation', 'org_id')
|
||||
|
||||
op.drop_constraint('api_keys_org_fkey', 'api_keys', type_='foreignkey')
|
||||
op.drop_column('api_keys', 'org_id')
|
||||
|
||||
op.drop_constraint('custom_secrets_org_fkey', 'custom_secrets', type_='foreignkey')
|
||||
op.drop_column('custom_secrets', 'org_id')
|
||||
|
||||
# Drop conversation_metadata_saas table
|
||||
op.drop_table('conversation_metadata_saas')
|
||||
|
||||
op.drop_constraint(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('billing_sessions', 'org_id')
|
||||
|
||||
# Drop tables in reverse order due to foreign key constraints
|
||||
op.drop_table('org_member')
|
||||
op.drop_table('user')
|
||||
op.drop_table('org')
|
||||
op.drop_table('role')
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Add git_user_name and git_user_email columns to user table.
|
||||
|
||||
Revision ID: 090
|
||||
Revises: 089
|
||||
Create Date: 2025-01-22
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '090'
|
||||
down_revision = '089'
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'user',
|
||||
sa.Column('git_user_name', sa.String, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
'user',
|
||||
sa.Column('git_user_email', sa.String, nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('user', 'git_user_email')
|
||||
op.drop_column('user', 'git_user_name')
|
||||
11562
enterprise/poetry.lock
generated
11562
enterprise/poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -4,10 +4,6 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Ensure SAAS configuration is used
|
||||
if not os.getenv('OPENHANDS_CONFIG_CLS'):
|
||||
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
|
||||
|
||||
import socketio # noqa: E402
|
||||
from fastapi import Request, status # noqa: E402
|
||||
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
||||
@@ -38,7 +34,6 @@ from server.routes.integration.linear import linear_integration_router # noqa:
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
@@ -91,7 +86,6 @@ if GITLAB_APP_CLIENT_ID:
|
||||
base_app.include_router(gitlab_integration_router)
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
add_github_proxy_routes(base_app)
|
||||
add_debugging_routes(
|
||||
base_app
|
||||
|
||||
@@ -77,15 +77,6 @@ class SaasUserAuth(UserAuth):
|
||||
self.access_token = SecretStr(tokens['access_token'])
|
||||
self.refresh_token = SecretStr(tokens['refresh_token'])
|
||||
self.refreshed = True
|
||||
if not self.email or not self.email_verified or not self.user_id:
|
||||
# We don't need to verify the signature here because we just refreshed
|
||||
# this token from the IDP via token_manager.refresh()
|
||||
access_token_payload = jwt.decode(
|
||||
tokens['access_token'], options={'verify_signature': False}
|
||||
)
|
||||
self.user_id = access_token_payload['sub']
|
||||
self.email = access_token_payload['email']
|
||||
self.email_verified = access_token_payload['email_verified']
|
||||
|
||||
def _is_token_expired(self, token: SecretStr):
|
||||
logger.debug('saas_user_auth_is_token_expired')
|
||||
@@ -112,6 +103,7 @@ class SaasUserAuth(UserAuth):
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
settings = await settings_store.load()
|
||||
# If load() returned None, should settings be created?
|
||||
if settings:
|
||||
settings.email = self.email
|
||||
settings.email_verified = self.email_verified
|
||||
@@ -282,13 +274,11 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
if not user_id:
|
||||
return None
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
saas_user_auth = SaasUserAuth(
|
||||
return SaasUserAuth(
|
||||
user_id=user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
auth_type=AuthType.BEARER,
|
||||
)
|
||||
await saas_user_auth.refresh()
|
||||
return saas_user_auth
|
||||
except Exception as exc:
|
||||
raise BearerTokenError from exc
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from keycloak.exceptions import (
|
||||
KeycloakError,
|
||||
KeycloakPostError,
|
||||
)
|
||||
from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import (
|
||||
BITBUCKET_APP_CLIENT_ID,
|
||||
BITBUCKET_APP_CLIENT_SECRET,
|
||||
@@ -427,8 +426,6 @@ class TokenManager:
|
||||
access_token = data.get('access_token')
|
||||
refresh_token = data.get('refresh_token')
|
||||
if not access_token or not refresh_token:
|
||||
if data.get('error') == 'bad_refresh_token':
|
||||
raise ExpiredError()
|
||||
raise ValueError(
|
||||
'Failed to refresh token: missing access_token or refresh_token in response.'
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ import socketio
|
||||
from server.logger import logger
|
||||
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -524,18 +524,16 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
)
|
||||
# Look up the user_id from the database
|
||||
with session_maker() as session:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
StoredConversationMetadata.conversation_id
|
||||
== conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_id = (
|
||||
str(conversation_metadata_saas.user_id)
|
||||
if conversation_metadata_saas
|
||||
else None
|
||||
conversation_metadata.user_id if conversation_metadata else None
|
||||
)
|
||||
# Handle the stopped conversation asynchronously
|
||||
asyncio.create_task(
|
||||
|
||||
@@ -19,8 +19,8 @@ IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
# Map of user settings versions to their corresponding default LLM models
|
||||
# This ensures that PERSONAL_WORKSPACE_VERSION_TO_MODEL and LITELLM_DEFAULT_MODEL stay in sync
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
# This ensures that CURRENT_USER_SETTINGS_VERSION and LITELLM_DEFAULT_MODEL stay in sync
|
||||
USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
1: 'claude-3-5-sonnet-20241022',
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
@@ -31,8 +31,7 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
|
||||
# Current user settings version - this should be the latest key in USER_SETTINGS_VERSION_TO_MODEL
|
||||
ORG_SETTINGS_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
PERSONAL_WORKSPACE_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
CURRENT_USER_SETTINGS_VERSION = max(USER_SETTINGS_VERSION_TO_MODEL.keys())
|
||||
|
||||
LITE_LLM_API_URL = os.environ.get(
|
||||
'LITE_LLM_API_URL', 'https://llm-proxy.app.all-hands.dev'
|
||||
@@ -56,6 +55,7 @@ SUBSCRIPTION_PRICE_DATA = {
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
STRIPE_WEBHOOK_SECRET = os.environ.get('STRIPE_WEBHOOK_SECRET', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
SLACK_CLIENT_ID = os.environ.get('SLACK_CLIENT_ID', None)
|
||||
@@ -91,5 +91,5 @@ def get_default_litellm_model():
|
||||
"""Construct proxy for litellm model based on user settings if not set explicitly."""
|
||||
if LITELLM_DEFAULT_MODEL:
|
||||
return LITELLM_DEFAULT_MODEL
|
||||
model = PERSONAL_WORKSPACE_VERSION_TO_MODEL[PERSONAL_WORKSPACE_VERSION]
|
||||
model = USER_SETTINGS_VERSION_TO_MODEL[CURRENT_USER_SETTINGS_VERSION]
|
||||
return build_litellm_proxy_model_path(model)
|
||||
|
||||
@@ -14,6 +14,7 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -107,10 +108,13 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
f'[GitHub] Sent summary instruction to conversation {conversation_id} {summary_event}'
|
||||
)
|
||||
|
||||
# Update the processor state - the outer session will commit this
|
||||
# Update the processor state
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -126,15 +130,14 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
|
||||
logger.info(f'[GitHub] Summary sent for conversation {conversation_id}')
|
||||
|
||||
# Mark callback as completed status - the outer session will commit this
|
||||
# Mark callback as completed status
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f'[GitHub] Error processing conversation callback: {str(e)}'
|
||||
)
|
||||
# Mark callback as error to prevent infinite re-invocation
|
||||
# The outer session will commit this
|
||||
callback.status = CallbackStatus.ERROR
|
||||
callback.updated_at = datetime.now()
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""
|
||||
Email domain validation utilities for enterprise endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_auth, get_user_id
|
||||
|
||||
|
||||
async def get_admin_user_id(
|
||||
request: Request, user_id: str | None = Depends(get_user_id)
|
||||
) -> str:
|
||||
"""
|
||||
Dependency that validates user has @openhands.dev email domain.
|
||||
|
||||
This dependency can be used in place of get_user_id for endpoints that
|
||||
should only be accessible to admin users. Currently, this is implemented
|
||||
by checking for @openhands.dev email domain.
|
||||
|
||||
TODO: In the future, this should be replaced with an explicit is_admin flag
|
||||
in user/org settings instead of relying on email domain validation.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
user_id: User ID from get_user_id dependency
|
||||
|
||||
Returns:
|
||||
str: User ID if email domain is valid
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if email domain is not @openhands.dev
|
||||
HTTPException: 401 if user is not authenticated
|
||||
|
||||
Example:
|
||||
@router.post('/endpoint')
|
||||
async def create_resource(
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
# Only admin users can access this endpoint
|
||||
pass
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
user_auth = await get_user_auth(request)
|
||||
user_email = await user_auth.get_user_email()
|
||||
|
||||
if not user_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User email not available',
|
||||
)
|
||||
|
||||
if not user_email.endswith('@openhands.dev'):
|
||||
logger.warning(
|
||||
'Access denied - invalid email domain',
|
||||
extra={'user_id': user_id, 'email_domain': user_email.split('@')[-1]},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Access restricted to @openhands.dev users',
|
||||
)
|
||||
|
||||
return user_id
|
||||
@@ -44,13 +44,11 @@ class MyProcessor(MaintenanceTaskProcessor):
|
||||
### UserVersionUpgradeProcessor
|
||||
|
||||
Located in `user_version_upgrade_processor.py`, this processor:
|
||||
|
||||
- Handles up to 100 user IDs per task
|
||||
- Upgrades users with `user_version < ORG_SETTINGS_VERSION`
|
||||
- Upgrades users with `user_version < CURRENT_USER_SETTINGS_VERSION`
|
||||
- Uses `SaasSettingsStore.create_default_settings()` for upgrades
|
||||
|
||||
**Usage:**
|
||||
|
||||
```python
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import UserVersionUpgradeProcessor
|
||||
|
||||
@@ -146,26 +144,22 @@ task = create_maintenance_task(
|
||||
## Best Practices
|
||||
|
||||
### Processor Design
|
||||
|
||||
- Keep tasks short-running (under 1 minute)
|
||||
- Handle errors gracefully and return meaningful error information
|
||||
- Use batch processing for large datasets
|
||||
- Include progress information in the return dict
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Always wrap your processor logic in try-catch blocks
|
||||
- Return structured error information
|
||||
- Log important events for debugging
|
||||
|
||||
### Performance
|
||||
|
||||
- Limit batch sizes to avoid long-running tasks
|
||||
- Use database sessions efficiently
|
||||
- Consider memory usage for large datasets
|
||||
|
||||
### Testing
|
||||
|
||||
- Create unit tests for your processors
|
||||
- Test error conditions
|
||||
- Verify the processor serialization/deserialization works correctly
|
||||
@@ -173,7 +167,6 @@ task = create_maintenance_task(
|
||||
## Database Patterns
|
||||
|
||||
The maintenance task system follows the repository's established patterns:
|
||||
|
||||
- Uses `session_maker()` for database operations
|
||||
- Wraps sync database operations in `call_sync_from_async` for async routes
|
||||
- Follows proper SQLAlchemy query patterns
|
||||
@@ -181,18 +174,15 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Integration with Existing Systems
|
||||
|
||||
### User Management
|
||||
|
||||
- Integrates with the existing `UserSettings` model
|
||||
- Uses the current user versioning system (`ORG_SETTINGS_VERSION`)
|
||||
- Uses the current user versioning system (`CURRENT_USER_SETTINGS_VERSION`)
|
||||
- Maintains compatibility with existing user management workflows
|
||||
|
||||
### Authentication
|
||||
|
||||
- Admin endpoints use the existing SaaS authentication system
|
||||
- Requires users to have `admin = True` in their UserSettings
|
||||
|
||||
### Monitoring
|
||||
|
||||
- Tasks are logged with structured information
|
||||
- Status updates are tracked in the database
|
||||
- Error information is preserved for debugging
|
||||
@@ -216,7 +206,6 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements that could be added:
|
||||
|
||||
- Task dependencies and scheduling
|
||||
- Retry mechanisms for failed tasks
|
||||
- Real-time progress updates
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.logger import logger
|
||||
from storage.database import session_maker
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskProcessor
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
|
||||
|
||||
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
|
||||
"""
|
||||
Processor for upgrading user settings to the current version.
|
||||
|
||||
This processor takes a list of user IDs and upgrades any users
|
||||
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
|
||||
"""
|
||||
|
||||
user_ids: List[str]
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""
|
||||
Process user version upgrades for the specified user IDs.
|
||||
|
||||
Args:
|
||||
task: The maintenance task being processed
|
||||
|
||||
Returns:
|
||||
dict: Results containing successful and failed user IDs
|
||||
"""
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:start',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_count': len(self.user_ids),
|
||||
'current_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
if len(self.user_ids) > 100:
|
||||
raise ValueError(
|
||||
f'Too many user IDs: {len(self.user_ids)}. Maximum is 100.'
|
||||
)
|
||||
|
||||
config = load_openhands_config()
|
||||
|
||||
# Track results
|
||||
successful_upgrades = []
|
||||
failed_upgrades = []
|
||||
users_already_current = []
|
||||
|
||||
# Find users that need upgrading
|
||||
with session_maker() as session:
|
||||
users_to_upgrade = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id.in_(self.user_ids),
|
||||
UserSettings.user_version < CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Track users that are already current
|
||||
users_needing_upgrade_ids = {u.keycloak_user_id for u in users_to_upgrade}
|
||||
users_already_current = [
|
||||
uid for uid in self.user_ids if uid not in users_needing_upgrade_ids
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:found_users',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'users_to_upgrade': len(users_to_upgrade),
|
||||
'users_already_current': len(users_already_current),
|
||||
'total_requested': len(self.user_ids),
|
||||
},
|
||||
)
|
||||
|
||||
# Process each user that needs upgrading
|
||||
for user_settings in users_to_upgrade:
|
||||
user_id = user_settings.keycloak_user_id
|
||||
old_version = user_settings.user_version
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:upgrading_user',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
# Create SaasSettingsStore instance and upgrade
|
||||
settings_store = await SaasSettingsStore.get_instance(config, user_id)
|
||||
await settings_store.create_default_settings(user_settings)
|
||||
|
||||
successful_upgrades.append(
|
||||
{
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:user_upgraded',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_upgrades.append(
|
||||
{'user_id': user_id, 'old_version': old_version, 'error': str(e)}
|
||||
)
|
||||
|
||||
logger.error(
|
||||
'user_version_upgrade_processor:user_upgrade_failed',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
# Create result summary
|
||||
result = {
|
||||
'total_users': len(self.user_ids),
|
||||
'users_already_current': users_already_current,
|
||||
'successful_upgrades': successful_upgrades,
|
||||
'failed_upgrades': failed_upgrades,
|
||||
'summary': (
|
||||
f'Processed {len(self.user_ids)} users: '
|
||||
f'{len(successful_upgrades)} upgraded, '
|
||||
f'{len(users_already_current)} already current, '
|
||||
f'{len(failed_upgrades)} errors'
|
||||
),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:completed',
|
||||
extra={'task_id': task.id, 'result': result},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@@ -34,7 +36,6 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
Returns:
|
||||
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
|
||||
"""
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
if user_id:
|
||||
|
||||
@@ -144,7 +144,7 @@ class SetAuthCookieMiddleware:
|
||||
# "if accepted_tos is not None" as there should not be any users with
|
||||
# accepted_tos equal to "None"
|
||||
if accepted_tos is False and request.url.path != '/api/accept_tos':
|
||||
logger.warning('User has not accepted the terms of service')
|
||||
logger.error('User has not accepted the terms of service')
|
||||
raise TosNotAcceptedError
|
||||
|
||||
def _should_attach(self, request: Request) -> bool:
|
||||
@@ -162,8 +162,6 @@ class SetAuthCookieMiddleware:
|
||||
'/api/email/resend',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
'/api/v1/web-client/config',
|
||||
'/api/v1/webhooks/secrets',
|
||||
)
|
||||
if path in ignore_paths:
|
||||
return False
|
||||
|
||||
@@ -1,95 +1,113 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
)
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.user_store import UserStore
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
def _get_byor_key():
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
if current_org_member.llm_api_key_for_byor:
|
||||
return current_org_member.llm_api_key_for_byor.get_secret_value()
|
||||
return None
|
||||
|
||||
return await call_sync_from_async(_get_byor_key)
|
||||
user_db_settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
if user_db_settings and user_db_settings.llm_api_key_for_byor:
|
||||
return user_db_settings.llm_api_key_for_byor
|
||||
return None
|
||||
|
||||
|
||||
async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
"""Store the BYOR key in the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
def _update_user_settings():
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
current_org_member.llm_api_key_for_byor = key
|
||||
OrgMemberStore.update_org_member(current_org_member)
|
||||
with session_maker() as session:
|
||||
user_db_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
if user_db_settings:
|
||||
user_db_settings.llm_api_key_for_byor = key
|
||||
session.commit()
|
||||
logger.info(
|
||||
'Successfully stored BYOR key in user settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'User settings not found when trying to store BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
await call_sync_from_async(_update_user_settings)
|
||||
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
return None
|
||||
current_org_id = str(user.current_org_id)
|
||||
key = await LiteLlmManager.generate_key(
|
||||
user_id,
|
||||
current_org_id,
|
||||
f'BYOR Key - user {user_id}, org {current_org_id}',
|
||||
{'type': 'byor'},
|
||||
)
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
'user_id': user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...' if key and len(key) > 10 else key,
|
||||
'metadata': {'type': 'byor'},
|
||||
'key_alias': f'BYOR Key - user {user_id}',
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...'
|
||||
if key and len(key) > 10
|
||||
else key,
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error generating BYOR key',
|
||||
@@ -98,25 +116,96 @@ async def generate_byor_key(user_id: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||
"""Delete the BYOR key from LiteLLM using the key directly.
|
||||
async def verify_byor_key_in_litellm(byor_key: str, user_id: str) -> bool:
|
||||
"""Verify that a BYOR key is valid in LiteLLM by making a lightweight API call.
|
||||
|
||||
Also attempts to delete by key alias if the key is not found,
|
||||
to clean up orphaned aliases that could block key regeneration.
|
||||
Args:
|
||||
byor_key: The BYOR key to verify
|
||||
user_id: The user ID for logging purposes
|
||||
|
||||
Returns:
|
||||
True if the key is verified as valid, False if verification fails or key is invalid.
|
||||
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
|
||||
"""
|
||||
try:
|
||||
# Get user to construct the key alias
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
key_alias = None
|
||||
if user and user.current_org_id:
|
||||
key_alias = f'BYOR Key - user {user_id}, org {user.current_org_id}'
|
||||
if not (LITE_LLM_API_URL and byor_key):
|
||||
return False
|
||||
|
||||
await LiteLlmManager.delete_key(byor_key, key_alias=key_alias)
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
) as client:
|
||||
# Make a lightweight request to verify the key
|
||||
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/v1/models',
|
||||
headers={
|
||||
'Authorization': f'Bearer {byor_key}',
|
||||
},
|
||||
)
|
||||
|
||||
# Only 200 status code indicates valid key
|
||||
if response.status_code == 200:
|
||||
logger.debug(
|
||||
'BYOR key verification successful',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
|
||||
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||
# This includes authentication errors and server errors
|
||||
logger.warning(
|
||||
'BYOR key verification failed - treating as invalid',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': response.status_code,
|
||||
'key_prefix': byor_key[:10] + '...'
|
||||
if len(byor_key) > 10
|
||||
else byor_key,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
except (httpx.TimeoutException, Exception) as e:
|
||||
# Any exception (timeout, network error, etc.) means we can't verify
|
||||
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||
logger.warning(
|
||||
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
'error_type': type(e).__name__,
|
||||
},
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||
"""Delete the BYOR key from LiteLLM using the key directly."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Delete the key directly using the key value
|
||||
delete_url = f'{LITE_LLM_API_URL}/key/delete'
|
||||
delete_payload = {'keys': [byor_key]}
|
||||
|
||||
delete_response = await client.post(delete_url, json=delete_payload)
|
||||
delete_response.raise_for_status()
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error deleting BYOR key from LiteLLM',
|
||||
@@ -268,7 +357,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
byor_key = await get_byor_key_from_db(user_id)
|
||||
if byor_key:
|
||||
# Validate that the key is actually registered in LiteLLM
|
||||
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
is_valid = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
if is_valid:
|
||||
return {'key': byor_key}
|
||||
else:
|
||||
@@ -323,6 +412,15 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='LiteLLM API configuration not found',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
existing_byor_key = await get_byor_key_from_db(user_id)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
@@ -23,12 +22,12 @@ from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.recaptcha_service import recaptcha_service
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import sign_token
|
||||
from server.config import get_config, sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from storage.database import session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
@@ -88,8 +87,7 @@ def get_cookie_domain(request: Request) -> str | None:
|
||||
# for now just use the full hostname except for staging stacks.
|
||||
return (
|
||||
None
|
||||
if not request.url.hostname
|
||||
or request.url.hostname.endswith('staging.all-hands.dev')
|
||||
if (request.url.hostname or '').endswith('staging.all-hand.dev')
|
||||
else request.url.hostname
|
||||
)
|
||||
|
||||
@@ -176,20 +174,6 @@ async def keycloak_callback(
|
||||
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Failed to authenticate user {user_info["preferred_username"]}'
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
|
||||
|
||||
# reCAPTCHA verification with Account Defender
|
||||
if RECAPTCHA_SITE_KEY:
|
||||
@@ -376,7 +360,15 @@ async def keycloak_callback(
|
||||
f'&state={state}'
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(user_id)
|
||||
has_accepted_tos = (
|
||||
user_settings is not None and user_settings.accepted_tos is not None
|
||||
)
|
||||
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
@@ -494,20 +486,28 @@ async def accept_tos(request: Request):
|
||||
redirect_url = body.get('redirect_url', str(request.base_url))
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
accepted_tos: datetime = datetime.now(timezone.utc)
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.accepted_tos = datetime.now(timezone.utc)
|
||||
session.merge(user_settings)
|
||||
else:
|
||||
# Create user settings if they don't exist
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
accepted_tos=datetime.now(timezone.utc),
|
||||
user_version=0, # This will trigger a migration to the latest version on next load
|
||||
)
|
||||
user.accepted_tos = accepted_tos
|
||||
session.add(user_settings)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}
|
||||
|
||||
@@ -4,42 +4,63 @@ from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
import stripe
|
||||
from dateutil.relativedelta import relativedelta # type: ignore
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
STRIPE_API_KEY,
|
||||
STRIPE_WEBHOOK_SECRET,
|
||||
SUBSCRIPTION_PRICE_DATA,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.app_server.config import get_global_config
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
|
||||
|
||||
async def validate_billing_enabled() -> None:
|
||||
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
|
||||
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
|
||||
def is_all_hands_saas_environment(request: Request) -> bool:
|
||||
"""Check if the current domain is an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
|
||||
"""
|
||||
Validate that the billing feature flag is enabled
|
||||
hostname = request.url.hostname or ''
|
||||
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
|
||||
|
||||
|
||||
def validate_saas_environment(request: Request) -> None:
|
||||
"""Validate that the request is coming from an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Raises:
|
||||
HTTPException: If the request is not from an All Hands SaaS environment
|
||||
"""
|
||||
config = get_global_config()
|
||||
web_client_config = await config.web_client.get_web_client_config()
|
||||
if not web_client_config.feature_flags.enable_billing:
|
||||
if not is_all_hands_saas_environment(request):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=(
|
||||
'Billing is disabled in this environment. '
|
||||
'Please set OH_WEB_CLIENT_FEATURE_FLAGS_ENABLE_BILLING to enable billing.'
|
||||
),
|
||||
detail='Checkout sessions are only available for All Hands SaaS environments',
|
||||
)
|
||||
|
||||
|
||||
@@ -85,20 +106,31 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
return max(max_budget - spend, 0.0)
|
||||
|
||||
|
||||
# Endpoint to retrieve the current organization's credit balance
|
||||
# Endpoint to retrieve user's current credit balance
|
||||
@billing_router.get('/credits')
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
# Update to use calculate_credits
|
||||
spend = user_team_info.get('spend', 0)
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=15.0
|
||||
) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f'litellm_get_user_failed: {type(e).__name__}: {e}',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': e.response.status_code,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve credit balance from billing service',
|
||||
)
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current subscription access
|
||||
@@ -133,7 +165,79 @@ async def get_subscription_access(
|
||||
async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await stripe_service.has_payment_method_by_user_id(user_id)
|
||||
return await stripe_service.has_payment_method(user_id)
|
||||
|
||||
|
||||
# Endpoint to cancel user's subscription
|
||||
@billing_router.post('/cancel-subscription')
|
||||
async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONResponse:
|
||||
"""Cancel user's active subscription at the end of the current billing period."""
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
with session_maker() as session:
|
||||
# Find the user's active subscription
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not already cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription_access:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='No active subscription found',
|
||||
)
|
||||
|
||||
if not subscription_access.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot cancel subscription: missing Stripe subscription ID',
|
||||
)
|
||||
|
||||
try:
|
||||
# Cancel the subscription in Stripe at period end
|
||||
await stripe.Subscription.modify_async(
|
||||
subscription_access.stripe_subscription_id, cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Update local database
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
'end_at': subscription_access.end_at,
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{'status': 'success', 'message': 'Subscription cancelled successfully'}
|
||||
)
|
||||
|
||||
except stripe.StripeError as e:
|
||||
logger.error(
|
||||
'stripe_cancellation_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f'Failed to cancel subscription: {str(e)}',
|
||||
)
|
||||
|
||||
|
||||
# Endpoint to create a new setup intent in stripe
|
||||
@@ -141,17 +245,17 @@ async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
await validate_billing_enabled()
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
base_url = _get_base_url(request)
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_info['customer_id'],
|
||||
customer=customer_id,
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{base_url}?free_credits=success',
|
||||
cancel_url=f'{base_url}',
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# Endpoint to create a new Stripe checkout session for credit purchase
|
||||
@@ -161,11 +265,11 @@ async def create_checkout_session(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
await validate_billing_enabled()
|
||||
base_url = _get_base_url(request)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_info['customer_id'],
|
||||
customer=customer_id,
|
||||
line_items=[
|
||||
{
|
||||
'price_data': {
|
||||
@@ -178,22 +282,21 @@ async def create_checkout_session(
|
||||
'tax_behavior': 'exclusive',
|
||||
},
|
||||
'quantity': 1,
|
||||
},
|
||||
}
|
||||
],
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_info['customer_id'],
|
||||
'stripe_customer_id': customer_id,
|
||||
'user_id': user_id,
|
||||
'org_id': customer_info['org_id'],
|
||||
'amount': body.amount,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
@@ -202,14 +305,105 @@ async def create_checkout_session(
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
org_id=customer_info['org_id'],
|
||||
price=body.amount,
|
||||
price_code='NA',
|
||||
billing_session_type=BillingSessionType.DIRECT_PAYMENT.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@billing_router.post('/subscription-checkout-session')
|
||||
async def create_subscription_checkout_session(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
# Prevent duplicate subscriptions for the same user
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
existing_active_subscription = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_active_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot create subscription: User already has an active subscription that has not been cancelled',
|
||||
)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
subscription_price_data = SUBSCRIPTION_PRICE_DATA[billing_session_type.value]
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
line_items=[
|
||||
{
|
||||
'price_data': subscription_price_data,
|
||||
'quantity': 1,
|
||||
}
|
||||
],
|
||||
mode='subscription',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
subscription_data={
|
||||
'metadata': {
|
||||
'user_id': user_id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_subscription_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'user_id': user_id,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
price=subscription_price_data['unit_amount'],
|
||||
price_code='NA',
|
||||
billing_session_type=billing_session_type.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(
|
||||
redirect_url=typing.cast(str, checkout_session.url)
|
||||
)
|
||||
|
||||
|
||||
@billing_router.get('/create-subscription-checkout-session')
|
||||
async def create_subscription_checkout_session_via_get(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> RedirectResponse:
|
||||
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
|
||||
validate_saas_environment(request)
|
||||
|
||||
response = await create_subscription_checkout_session(
|
||||
request, billing_session_type, user_id
|
||||
)
|
||||
return RedirectResponse(response.redirect_url)
|
||||
|
||||
|
||||
# Callback endpoint for successful Stripe payments - updates user credits and billing session status
|
||||
@@ -231,6 +425,15 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Any non direct payment (Subscription) is processed in the invoice_payment.paid by the webhook
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
!= BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
stripe_session = stripe.checkout.Session.retrieve(session_id)
|
||||
if stripe_session.status != 'complete':
|
||||
# Hopefully this never happens - we get a redirect from stripe where the payment is not yet complete
|
||||
@@ -244,40 +447,34 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
user = await UserStore.get_user_by_id_async(billing_session.user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
new_max_budget = max_budget + add_credits
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
# Update max budget in litellm
|
||||
user_json = await _get_litellm_user(client, billing_session.user_id)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
new_max_budget = (
|
||||
(user_json.get('user_info') or {}).get('max_budget') or 0
|
||||
) + add_credits
|
||||
await _upsert_litellm_user(client, billing_session.user_id, new_max_budget)
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'org_id': str(user.current_org_id),
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = amount_subtotal
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
|
||||
f'{request.base_url}settings/billing?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@@ -304,14 +501,206 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
session.merge(billing_session)
|
||||
session.commit()
|
||||
|
||||
# Redirect credit purchases to billing screen, subscriptions to LLM settings
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
== BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=cancel',
|
||||
status_code=302,
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
# If no billing session found, default to LLM settings (subscription flow)
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
def _get_base_url(request: Request) -> URL:
|
||||
# Never send any part of the credit card process over a non secure connection
|
||||
base_url = request.base_url
|
||||
if base_url.hostname != 'localhost':
|
||||
base_url = base_url.replace(scheme='https')
|
||||
return base_url
|
||||
@billing_router.post('/stripe-webhook')
|
||||
async def stripe_webhook(request: Request) -> JSONResponse:
|
||||
"""Endpoint for stripe webhooks."""
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400, detail=f'Invalid payload: {e}')
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400, detail=f'Invalid signature: {e}')
|
||||
|
||||
# Handle the event
|
||||
logger.info('stripe_webhook_event', extra={'event': event})
|
||||
event_type = event['type']
|
||||
if event_type == 'invoice.paid':
|
||||
invoice = event['data']['object']
|
||||
amount_paid = invoice.amount_paid
|
||||
metadata = invoice.parent.subscription_details.metadata # type: ignore
|
||||
billing_session_type = metadata.billing_session_type
|
||||
assert (
|
||||
amount_paid == SUBSCRIPTION_PRICE_DATA[billing_session_type]['unit_amount']
|
||||
)
|
||||
user_id = metadata.user_id
|
||||
|
||||
start_at = datetime.now(UTC)
|
||||
if billing_session_type == BillingSessionType.MONTHLY_SUBSCRIPTION.value:
|
||||
end_at = start_at + relativedelta(months=1)
|
||||
else:
|
||||
raise ValueError(f'unknown_billing_session_type:{billing_session_type}')
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = SubscriptionAccess(
|
||||
status='ACTIVE',
|
||||
user_id=user_id,
|
||||
start_at=start_at,
|
||||
end_at=end_at,
|
||||
amount_paid=amount_paid,
|
||||
stripe_invoice_payment_id=invoice.payment_intent,
|
||||
stripe_subscription_id=invoice.subscription, # Store Stripe subscription ID
|
||||
)
|
||||
session.add(subscription_access)
|
||||
session.commit()
|
||||
elif event_type == 'customer.subscription.updated':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
# Handle subscription cancellation
|
||||
if subscription.get('cancel_at_period_end') is True:
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(
|
||||
SubscriptionAccess.stripe_subscription_id == subscription_id
|
||||
)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access and not subscription_access.cancelled_at:
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled_via_webhook',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
elif event_type == 'customer.subscription.deleted':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.stripe_subscription_id == subscription_id)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access:
|
||||
subscription_access.status = 'DISABLED'
|
||||
subscription_access.updated_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
# Reset user settings to free tier defaults
|
||||
reset_user_to_free_tier_settings(subscription_access.user_id)
|
||||
|
||||
logger.info(
|
||||
'subscription_expired_reset_to_free_tier',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info('stripe_webhook_unhandled_event_type', extra={'type': event_type})
|
||||
|
||||
return JSONResponse({'status': 'success'})
|
||||
|
||||
|
||||
def reset_user_to_free_tier_settings(user_id: str) -> None:
|
||||
"""Reset user settings to free tier defaults when subscription ends."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_api_key = None
|
||||
user_settings.llm_api_key_for_byor = None
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
user_settings.max_budget_per_task = None
|
||||
user_settings.confirmation_mode = False
|
||||
user_settings.enable_solvability_analysis = False
|
||||
user_settings.security_analyzer = 'llm'
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
user_settings.language = 'en'
|
||||
user_settings.enable_default_condenser = True
|
||||
user_settings.enable_sound_notifications = False
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
user_settings.user_consents_to_analytics = False
|
||||
|
||||
session.merge(user_settings)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_settings_reset_to_free_tier',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'reset_timestamp': datetime.now(UTC).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _get_litellm_user(client: httpx.AsyncClient, user_id: str) -> dict:
|
||||
"""Get a user from litellm with the id matching that given.
|
||||
|
||||
If no such user exists, returns a dummy user in the format:
|
||||
`{'user_id': '<USER_ID>', 'user_info': {'spend': 0}, 'keys': [], 'teams': []}`
|
||||
"""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
async def _upsert_litellm_user(
|
||||
client: httpx.AsyncClient, user_id: str, max_budget: float
|
||||
):
|
||||
"""Insert / Update a user in litellm."""
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
json={
|
||||
'user_id': user_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -5,8 +5,8 @@ from threading import Thread
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from sqlalchemy import func, select
|
||||
from storage.database import a_session_maker, get_engine, session_maker
|
||||
from storage.user import User
|
||||
from storage.database import a_session_maker, engine, session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import wait_all
|
||||
@@ -47,7 +47,6 @@ def add_debugging_routes(api: FastAPI):
|
||||
- checked_out: Number of connections currently in use
|
||||
- overflow: Number of overflow connections created beyond pool_size
|
||||
"""
|
||||
engine = get_engine()
|
||||
return {
|
||||
'checked_in': engine.pool.checkedin(),
|
||||
'checked_out': engine.pool.checkedout(),
|
||||
@@ -128,9 +127,8 @@ def _db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
with session_maker() as session:
|
||||
num_users = session.query(User).count()
|
||||
num_users = session.query(UserSettings).count()
|
||||
time.sleep(delay)
|
||||
engine = get_engine()
|
||||
logger.info(
|
||||
'check',
|
||||
extra={
|
||||
@@ -157,7 +155,7 @@ async def _a_db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
async with a_session_maker() as a_session:
|
||||
stmt = select(func.count(User.id))
|
||||
stmt = select(func.count(UserSettings.id))
|
||||
num_users = await a_session.execute(stmt)
|
||||
await asyncio.sleep(delay)
|
||||
logger.info(f'a_num_users:{num_users.scalar_one()}')
|
||||
|
||||
@@ -21,7 +21,7 @@ from server.utils.conversation_callback_utils import (
|
||||
update_conversation_stats,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.server.shared import conversation_manager
|
||||
|
||||
@@ -226,12 +226,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
|
||||
|
||||
def _get_user_id(conversation_id: str) -> str:
|
||||
with session_maker() as session:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
return conversation_metadata.user_id
|
||||
|
||||
|
||||
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.shared import file_store
|
||||
@@ -33,10 +33,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zlib
|
||||
from base64 import b64decode, b64encode
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
@@ -52,11 +51,7 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
state_payload = json.dumps(
|
||||
[query_params['state'][0], query_params['redirect_uri'][0]]
|
||||
)
|
||||
# Compress before encrypting to reduce URL length
|
||||
# This is critical for feature deployments where reCAPTCHA tokens in state
|
||||
# can cause "URL too long" errors from GitHub
|
||||
compressed_payload = zlib.compress(state_payload.encode())
|
||||
state = b64encode(_fernet().encrypt(compressed_payload)).decode()
|
||||
state = b64encode(_fernet().encrypt(state_payload.encode())).decode()
|
||||
query_params['state'] = [state]
|
||||
query_params['redirect_uri'] = [
|
||||
f'https://{request.url.netloc}/github-proxy/callback'
|
||||
@@ -72,9 +67,7 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
parsed_url = urlparse(str(request.url))
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
state = query_params['state'][0]
|
||||
# Decrypt and decompress (reverse of github_proxy_start)
|
||||
decrypted_payload = _fernet().decrypt(b64decode(state.encode()))
|
||||
decrypted_state = zlib.decompress(decrypted_payload).decode()
|
||||
decrypted_state = _fernet().decrypt(b64decode(state.encode())).decode()
|
||||
|
||||
# Build query Params
|
||||
state, redirect_uri = json.loads(decrypted_state)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -7,16 +5,15 @@ import uuid
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request, status
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from integrations.jira.jira_manager import JiraManager
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import HOST_URL
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from server.auth.constants import JIRA_CLIENT_ID, JIRA_CLIENT_SECRET
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from server.constants import WEB_HOST
|
||||
from storage.redis import create_redis_client
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -27,7 +24,7 @@ JIRA_WEBHOOKS_ENABLED = os.environ.get('JIRA_WEBHOOKS_ENABLED', '0') in (
|
||||
'1',
|
||||
'true',
|
||||
)
|
||||
JIRA_REDIRECT_URI = f'{HOST_URL}/integration/jira/callback'
|
||||
JIRA_REDIRECT_URI = f'https://{WEB_HOST}/integration/jira/callback'
|
||||
JIRA_SCOPES = 'read:me read:jira-user read:jira-work'
|
||||
JIRA_AUTH_URL = 'https://auth.atlassian.com/authorize'
|
||||
JIRA_TOKEN_URL = 'https://auth.atlassian.com/oauth/token'
|
||||
@@ -125,63 +122,6 @@ jira_manager = JiraManager(token_manager)
|
||||
redis_client = create_redis_client()
|
||||
|
||||
|
||||
async def verify_jira_signature(body: bytes, signature: str, payload: dict):
|
||||
"""
|
||||
Verify Jira webhook signature.
|
||||
|
||||
Args:
|
||||
body: Raw request body bytes
|
||||
signature: Signature from x-hub-signature header (format: "sha256=<hash>")
|
||||
payload: Parsed JSON payload from webhook
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if signature verification fails or workspace is invalid
|
||||
|
||||
Returns:
|
||||
None (raises exception on failure)
|
||||
"""
|
||||
|
||||
if not signature:
|
||||
raise HTTPException(
|
||||
status_code=403, detail='x-hub-signature header is missing!'
|
||||
)
|
||||
|
||||
workspace_name = jira_manager.get_workspace_name_from_payload(payload)
|
||||
if workspace_name is None:
|
||||
logger.warning('[Jira] No workspace name found in webhook payload')
|
||||
raise HTTPException(
|
||||
status_code=403, detail='Workspace name not found in payload'
|
||||
)
|
||||
|
||||
workspace: (
|
||||
JiraWorkspace | None
|
||||
) = await jira_manager.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if workspace is None:
|
||||
logger.warning(f'[Jira] Could not identify workspace {workspace_name}')
|
||||
raise HTTPException(status_code=403, detail='Unidentified workspace')
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(
|
||||
'[Jira] Workspace is inactive',
|
||||
extra={
|
||||
'jira_workspace_id': workspace.id,
|
||||
'parsed_workspace_name': workspace.name,
|
||||
'status': workspace.status,
|
||||
},
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=403, detail='Workspace is inactive')
|
||||
|
||||
webhook_secret = token_manager.decrypt_text(workspace.webhook_secret)
|
||||
expected_signature = hmac.new(
|
||||
webhook_secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(expected_signature, signature):
|
||||
raise HTTPException(status_code=403, detail="Request signatures didn't match!")
|
||||
|
||||
|
||||
async def _handle_workspace_link_creation(
|
||||
user_id: str, jira_user_id: str, target_workspace: str
|
||||
):
|
||||
@@ -276,7 +216,6 @@ async def _validate_workspace_update_permissions(user_id: str, target_workspace:
|
||||
async def jira_events(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_hub_signature: str = Header(None),
|
||||
):
|
||||
"""Handle Jira webhook events."""
|
||||
# Check if Jira webhooks are enabled
|
||||
@@ -288,15 +227,13 @@ async def jira_events(
|
||||
)
|
||||
|
||||
try:
|
||||
parts = x_hub_signature.split('=', 1)
|
||||
if not (len(parts) == 2 and parts[1]):
|
||||
raise HTTPException(status_code=403, detail='Malformed x-hub-signature!')
|
||||
signature_valid, signature, payload = await jira_manager.validate_request(
|
||||
request
|
||||
)
|
||||
|
||||
signature = parts[1]
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
|
||||
await verify_jira_signature(body, signature, payload)
|
||||
if not signature_valid:
|
||||
logger.warning('[Jira] Invalid webhook signature')
|
||||
raise HTTPException(status_code=403, detail='Invalid webhook signature!')
|
||||
|
||||
# Check for duplicate requests using Redis
|
||||
key = f'jira:{signature}'
|
||||
|
||||
@@ -15,6 +15,7 @@ from integrations.slack.slack_manager import SlackManager
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import (
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
@@ -34,7 +35,6 @@ from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.shared import config, sio
|
||||
@@ -79,14 +79,6 @@ async def install_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if not config.jwt_secret:
|
||||
logger.error('slack_install_callback_error JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
try:
|
||||
client = AsyncWebClient() # no prepared token needed for this
|
||||
# Complete the installation by calling oauth.v2.access API method
|
||||
@@ -102,17 +94,16 @@ async def install_callback(
|
||||
|
||||
# Create a state variable for keycloak oauth
|
||||
payload = {}
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if state:
|
||||
payload = jwt.decode(
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
payload['slack_user_id'] = authed_user.get('id')
|
||||
payload['bot_access_token'] = bot_access_token
|
||||
payload['team_id'] = team_id
|
||||
|
||||
state = jwt.encode(
|
||||
payload, config.jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
state = jwt.encode(payload, jwt_secret.get_secret_value(), algorithm='HS256')
|
||||
|
||||
# Redirect into keycloak
|
||||
scope = quote('openid email profile offline_access')
|
||||
@@ -158,16 +149,9 @@ async def keycloak_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if not config.jwt_secret:
|
||||
logger.error('problem_retrieving_keycloak_tokens JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
payload: dict[str, str] = jwt.decode(
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
slack_user_id = payload['slack_user_id']
|
||||
bot_access_token = payload['bot_access_token']
|
||||
@@ -196,13 +180,6 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
keycloak_user_id = user_info['sub']
|
||||
user = await UserStore.get_user_by_id_async(keycloak_user_id)
|
||||
if not user:
|
||||
return _html_response(
|
||||
title='Failed to authenticate.',
|
||||
description=f'Please re-login into <a href="{HOST_URL}" style="color:#ecedee;text-decoration:underline;">OpenHands Cloud</a>. Then try <a href="https://docs.all-hands.dev/usage/cloud/slack-installation" style="color:#ecedee;text-decoration:underline;">installing the OpenHands Slack App</a> again',
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# These tokens are offline access tokens - store them!
|
||||
await token_manager.store_offline_token(keycloak_user_id, keycloak_refresh_token)
|
||||
@@ -234,7 +211,6 @@ async def keycloak_callback(
|
||||
slack_display_name = slack_user_info.data['user']['profile']['display_name']
|
||||
slack_user = SlackUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
org_id=user.current_org_id,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_display_name=slack_display_name,
|
||||
)
|
||||
@@ -329,7 +305,7 @@ async def on_form_interaction(request: Request, background_tasks: BackgroundTask
|
||||
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
payload = json.loads(form.get('payload'))
|
||||
payload = json.loads(form.get('payload')) # type: ignore[arg-type]
|
||||
|
||||
logger.info('slack_on_form_interaction', extra={'payload': payload})
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ DEVICE_CODE_EXPIRES_IN = 600 # 10 minutes
|
||||
DEVICE_TOKEN_POLL_INTERVAL = 5 # seconds
|
||||
|
||||
API_KEY_NAME = 'Device Link Access Key'
|
||||
KEY_EXPIRATION_TIME = timedelta(days=7) # Key expires in a week
|
||||
KEY_EXPIRATION_TIME = timedelta(days=1) # Key expires in 24 hours
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Models
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
class OrgCreationError(Exception):
|
||||
"""Base exception for organization creation errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OrgNameExistsError(OrgCreationError):
|
||||
"""Raised when an organization name already exists."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
super().__init__(f'Organization with name "{name}" already exists')
|
||||
|
||||
|
||||
class LiteLLMIntegrationError(OrgCreationError):
|
||||
"""Raised when LiteLLM integration fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OrgDatabaseError(OrgCreationError):
|
||||
"""Raised when database operations fail."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OrgDeletionError(Exception):
|
||||
"""Base exception for organization deletion errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OrgAuthorizationError(OrgDeletionError):
|
||||
"""Raised when user is not authorized to delete organization."""
|
||||
|
||||
def __init__(self, message: str = 'Not authorized to delete organization'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OrgNotFoundError(Exception):
|
||||
"""Raised when organization is not found or user doesn't have access."""
|
||||
|
||||
def __init__(self, org_id: str):
|
||||
self.org_id = org_id
|
||||
super().__init__(f'Organization with id "{org_id}" not found')
|
||||
|
||||
|
||||
class OrgCreate(BaseModel):
|
||||
"""Request model for creating a new organization."""
|
||||
|
||||
# Required fields
|
||||
name: str = Field(min_length=1, max_length=255, strip_whitespace=True)
|
||||
contact_name: str
|
||||
contact_email: EmailStr = Field(strip_whitespace=True)
|
||||
|
||||
|
||||
class OrgResponse(BaseModel):
|
||||
"""Response model for organization."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
contact_name: str
|
||||
contact_email: str
|
||||
conversation_expiration: int | None = None
|
||||
agent: str | None = None
|
||||
default_max_iterations: int | None = None
|
||||
security_analyzer: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
default_llm_model: str | None = None
|
||||
default_llm_api_key_for_byor: str | None = None
|
||||
default_llm_base_url: str | None = None
|
||||
remote_runtime_resource_factor: int | None = None
|
||||
enable_default_condenser: bool = True
|
||||
billing_margin: float | None = None
|
||||
enable_proactive_conversation_starters: bool = True
|
||||
sandbox_base_container_image: str | None = None
|
||||
sandbox_runtime_container_image: str | None = None
|
||||
org_version: int = 0
|
||||
mcp_config: dict | None = None
|
||||
search_api_key: str | None = None
|
||||
sandbox_api_key: str | None = None
|
||||
max_budget_per_task: float | None = None
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
credits: float | None = None
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
|
||||
"""Create an OrgResponse from an Org entity.
|
||||
|
||||
Args:
|
||||
org: The organization entity to convert
|
||||
credits: Optional credits value (defaults to None)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The response model instance
|
||||
"""
|
||||
return cls(
|
||||
id=str(org.id),
|
||||
name=org.name,
|
||||
contact_name=org.contact_name,
|
||||
contact_email=org.contact_email,
|
||||
conversation_expiration=org.conversation_expiration,
|
||||
agent=org.agent,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_api_key_for_byor=None,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser
|
||||
if org.enable_default_condenser is not None
|
||||
else True,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
|
||||
if org.enable_proactive_conversation_starters is not None
|
||||
else True,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
org_version=org.org_version if org.org_version is not None else 0,
|
||||
mcp_config=org.mcp_config,
|
||||
search_api_key=None,
|
||||
sandbox_api_key=None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
)
|
||||
|
||||
|
||||
class OrgPage(BaseModel):
|
||||
"""Paginated response model for organization list."""
|
||||
|
||||
items: list[OrgResponse]
|
||||
next_page_id: str | None = None
|
||||
|
||||
|
||||
class OrgUpdate(BaseModel):
|
||||
"""Request model for updating an organization."""
|
||||
|
||||
# Basic organization information (any authenticated user can update)
|
||||
contact_name: str | None = None
|
||||
contact_email: EmailStr | None = Field(default=None, strip_whitespace=True)
|
||||
conversation_expiration: int | None = None
|
||||
default_max_iterations: int | None = Field(default=None, gt=0)
|
||||
remote_runtime_resource_factor: int | None = Field(default=None, gt=0)
|
||||
billing_margin: float | None = Field(default=None, ge=0, le=1)
|
||||
enable_proactive_conversation_starters: bool | None = None
|
||||
sandbox_base_container_image: str | None = None
|
||||
sandbox_runtime_container_image: str | None = None
|
||||
mcp_config: dict | None = None
|
||||
sandbox_api_key: str | None = None
|
||||
max_budget_per_task: float | None = Field(default=None, gt=0)
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
|
||||
# LLM settings (require admin/owner role)
|
||||
default_llm_model: str | None = None
|
||||
default_llm_api_key_for_byor: str | None = None
|
||||
default_llm_base_url: str | None = None
|
||||
search_api_key: str | None = None
|
||||
security_analyzer: str | None = None
|
||||
agent: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
enable_default_condenser: bool | None = None
|
||||
condenser_max_size: int | None = Field(default=None, ge=20)
|
||||
@@ -1,402 +0,0 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgAuthorizationError,
|
||||
OrgCreate,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrgPage,
|
||||
OrgResponse,
|
||||
OrgUpdate,
|
||||
)
|
||||
from storage.org_service import OrgService
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# Initialize API router
|
||||
org_router = APIRouter(prefix='/api/organizations')
|
||||
|
||||
|
||||
@org_router.get('', response_model=OrgPage)
|
||||
async def list_user_orgs(
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
] = 100,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgPage:
|
||||
"""List organizations for the authenticated user.
|
||||
|
||||
This endpoint returns a paginated list of all organizations that the
|
||||
authenticated user is a member of.
|
||||
|
||||
Args:
|
||||
page_id: Optional page ID (offset) for pagination
|
||||
limit: Maximum number of organizations to return (1-100, default 100)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgPage: Paginated list of organizations
|
||||
|
||||
Raises:
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Listing organizations for user',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'page_id': page_id,
|
||||
'limit': limit,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch organizations from service layer
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Convert Org entities to OrgResponse objects
|
||||
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organizations',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_count': len(org_responses),
|
||||
'has_more': next_page_id is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return OrgPage(items=org_responses, next_page_id=next_page_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error listing organizations',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve organizations',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post('', response_model=OrgResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_org(
|
||||
org_data: OrgCreate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Create a new organization.
|
||||
|
||||
This endpoint allows authenticated users with @openhands.dev email to create
|
||||
a new organization. The user who creates the organization automatically becomes
|
||||
its owner.
|
||||
|
||||
Args:
|
||||
org_data: Organization creation data
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The created organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user email domain is not @openhands.dev
|
||||
HTTPException: 409 if organization name already exists
|
||||
HTTPException: 500 if creation fails
|
||||
"""
|
||||
logger.info(
|
||||
'Creating new organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_name': org_data.name,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to create organization
|
||||
org = await OrgService.create_org_with_owner(
|
||||
name=org_data.name,
|
||||
contact_name=org_data.contact_name,
|
||||
contact_email=org_data.contact_email,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
)
|
||||
except LiteLLMIntegrationError as e:
|
||||
logger.error(
|
||||
'LiteLLM integration failed',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create LiteLLM integration',
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database operation failed',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error creating organization',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
|
||||
async def get_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Get organization details by ID.
|
||||
|
||||
This endpoint allows authenticated users who are members of an organization
|
||||
to retrieve its details. Only members of the organization can access this endpoint.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 404 if organization not found or user is not a member
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Retrieving organization details',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to get organization with membership validation
|
||||
org = await OrgService.get_org_by_id(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error retrieving organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
|
||||
async def delete_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
) -> dict:
|
||||
"""Delete an organization.
|
||||
|
||||
This endpoint allows authenticated organization owners to delete their organization.
|
||||
All associated data including organization members, conversations, billing data,
|
||||
and external LiteLLM team resources will be permanently removed.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to delete
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with deleted organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is not the organization owner
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if deletion fails
|
||||
"""
|
||||
logger.info(
|
||||
'Organization deletion requested',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to delete organization with cleanup
|
||||
deleted_org = await OrgService.delete_org_with_cleanup(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Organization deletion completed successfully',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': deleted_org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
'message': 'Organization deleted successfully',
|
||||
'organization': {
|
||||
'id': str(deleted_org.id),
|
||||
'name': deleted_org.name,
|
||||
'contact_name': deleted_org.contact_name,
|
||||
'contact_email': deleted_org.contact_email,
|
||||
},
|
||||
}
|
||||
|
||||
except OrgNotFoundError as e:
|
||||
logger.warning(
|
||||
'Organization not found for deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgAuthorizationError as e:
|
||||
logger.warning(
|
||||
'User not authorized to delete organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database error during organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to delete organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error during organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.patch('/{org_id}', response_model=OrgResponse)
|
||||
async def update_org(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Update an existing organization.
|
||||
|
||||
This endpoint allows authenticated users to update organization settings.
|
||||
LLM-related settings require admin or owner role in the organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to update (UUID validated by FastAPI)
|
||||
update_data: Organization update data
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The updated organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
|
||||
HTTPException: 403 if user lacks permission for LLM settings
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 422 if validation errors occur (handled by FastAPI)
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Updating organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to update organization with permission checks
|
||||
updated_org = await OrgService.update_org_with_permissions(
|
||||
org_id=org_id,
|
||||
update_data=update_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
|
||||
credits = await OrgService.get_org_credits(user_id, updated_org.id)
|
||||
|
||||
return OrgResponse.from_org(updated_org, credits=credits)
|
||||
|
||||
except ValueError as e:
|
||||
# Organization not found
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except PermissionError as e:
|
||||
# User lacks permission for LLM settings
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database operation failed',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error updating organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
@@ -23,7 +23,6 @@ from sqlalchemy import orm
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
@@ -647,18 +646,16 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
"""
|
||||
|
||||
with session_maker() as session:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id
|
||||
)
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation_metadata_saas:
|
||||
if not conversation_metadata:
|
||||
raise ValueError(f'No conversation found {conversation_id}')
|
||||
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
return conversation_metadata.user_id
|
||||
|
||||
async def _get_runtime_status_from_nested_runtime(
|
||||
self, session_api_key: Any | None, nested_url: str, conversation_id: str
|
||||
@@ -997,17 +994,9 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
with session_maker() as session:
|
||||
# Only include conversations updated in the past week
|
||||
one_week_ago = datetime.now(UTC) - timedelta(days=7)
|
||||
query = (
|
||||
session.query(StoredConversationMetadata.conversation_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
)
|
||||
query = session.query(StoredConversationMetadata.conversation_id).filter(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
)
|
||||
user_conversation_ids = set(query)
|
||||
return user_conversation_ids
|
||||
@@ -1081,16 +1070,11 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None or conversation_metadata_saas is None:
|
||||
if conversation_metadata is None:
|
||||
# Conversation is running in different server
|
||||
return
|
||||
|
||||
user_id = conversation_metadata_saas.user_id
|
||||
user_id = conversation_metadata.user_id
|
||||
|
||||
# Get the id of the next event which is not present
|
||||
events_dir = get_conversation_events_dir(
|
||||
|
||||
@@ -26,7 +26,6 @@ from server.sharing.shared_conversation_models import (
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
@@ -58,7 +57,7 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
include_sub_conversations: bool = False,
|
||||
) -> SharedConversationPage:
|
||||
"""Search for shared conversations."""
|
||||
query = self._public_select_with_saas_metadata()
|
||||
query = self._public_select()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
@@ -105,17 +104,14 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
rows = result.scalars().all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [
|
||||
self._to_shared_conversation(stored, saas_metadata=saas_metadata)
|
||||
for stored, saas_metadata in rows
|
||||
]
|
||||
items = [self._to_shared_conversation(row) for row in rows]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
@@ -156,18 +152,17 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single public conversation info, returning None if missing or not shared."""
|
||||
query = self._public_select_with_saas_metadata().where(
|
||||
query = self._public_select().where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
row = result.first()
|
||||
stored = result.scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
stored, saas_metadata = row
|
||||
return self._to_shared_conversation(stored, saas_metadata=saas_metadata)
|
||||
return self._to_shared_conversation(stored)
|
||||
|
||||
def _public_select(self):
|
||||
"""Create a select query that only returns public conversations."""
|
||||
@@ -178,25 +173,6 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
return query
|
||||
|
||||
def _public_select_with_saas_metadata(self):
|
||||
"""Create a select query that returns public conversations with SAAS metadata.
|
||||
|
||||
This joins with conversation_metadata_saas to retrieve the user_id needed
|
||||
for constructing the correct event storage path. Uses LEFT OUTER JOIN to
|
||||
support conversations that may not have SAAS metadata (e.g., in tests).
|
||||
"""
|
||||
query = (
|
||||
select(StoredConversationMetadata, StoredConversationMetadataSaas)
|
||||
.outerjoin(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
)
|
||||
return query
|
||||
|
||||
def _apply_filters(
|
||||
self,
|
||||
query,
|
||||
@@ -235,16 +211,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
def _to_shared_conversation(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas | None = None,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> SharedConversation:
|
||||
"""Convert StoredConversationMetadata to SharedConversation.
|
||||
|
||||
Args:
|
||||
stored: The base conversation metadata from conversation_metadata table.
|
||||
saas_metadata: Optional SAAS metadata containing user_id and org_id.
|
||||
sub_conversation_ids: Optional list of sub-conversation IDs.
|
||||
"""
|
||||
"""Convert StoredConversationMetadata to SharedConversation."""
|
||||
# V1 conversations should always have a sandbox_id
|
||||
sandbox_id = stored.sandbox_id
|
||||
assert sandbox_id is not None
|
||||
@@ -270,16 +239,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
created_at = self._fix_timezone(stored.created_at)
|
||||
updated_at = self._fix_timezone(stored.last_updated_at)
|
||||
|
||||
# Get user_id from SAAS metadata if available
|
||||
created_by_user_id = (
|
||||
str(saas_metadata.user_id)
|
||||
if saas_metadata and saas_metadata.user_id
|
||||
else None
|
||||
)
|
||||
|
||||
return SharedConversation(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=created_by_user_id,
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
|
||||
@@ -1,350 +0,0 @@
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import func, select
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user import User
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
AppConversationInfoService,
|
||||
AppConversationInfoServiceInjector,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
AppConversationInfoPage,
|
||||
AppConversationSortOrder,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
|
||||
|
||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
||||
|
||||
async def _secure_select(self):
|
||||
query = (
|
||||
select(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def _secure_select_with_saas_metadata(self):
|
||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||
query = (
|
||||
select(StoredConversationMetadata, StoredConversationMetadataSaas)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def search_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> AppConversationInfoPage:
|
||||
"""Search for conversations with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
# Exclude sub-conversations (only include top-level conversations)
|
||||
query = query.where(
|
||||
StoredConversationMetadata.parent_conversation_id.is_(None)
|
||||
)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
if sort_order == AppConversationSortOrder.CREATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.created_at)
|
||||
elif sort_order == AppConversationSortOrder.CREATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.TITLE:
|
||||
query = query.order_by(StoredConversationMetadata.title)
|
||||
elif sort_order == AppConversationSortOrder.TITLE_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||
|
||||
# Apply pagination
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Apply limit and get one extra to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
for stored_metadata, saas_metadata in rows
|
||||
]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return AppConversationInfoPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def count_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
select(func.count(StoredConversationMetadata.conversation_id))
|
||||
.select_from(
|
||||
StoredConversationMetadata.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
# Apply user filtering
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
count = result.scalar()
|
||||
return count or 0
|
||||
|
||||
def _apply_filters_with_saas_metadata(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply filters to query that includes SAAS metadata."""
|
||||
# Apply the same filters as the base class
|
||||
conditions = []
|
||||
if title__contains is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.title.like(f'%{title__contains}%')
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at >= created_at__gte)
|
||||
|
||||
if created_at__lt is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
return query
|
||||
|
||||
async def get_app_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> AppConversationInfo | None:
|
||||
"""Get conversation info with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
result_set = await self.db_session.execute(query)
|
||||
result = result_set.first()
|
||||
if result:
|
||||
stored_metadata, saas_metadata = result
|
||||
return self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
return None
|
||||
|
||||
async def batch_get_app_conversation_info(
|
||||
self, conversation_ids: list[UUID]
|
||||
) -> list[AppConversationInfo | None]:
|
||||
"""Batch get conversation info with user_id from SAAS metadata."""
|
||||
conversation_id_strs = [
|
||||
str(conversation_id) for conversation_id in conversation_ids
|
||||
]
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id.in_(conversation_id_strs)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Create a mapping of conversation_id to (metadata, saas_metadata)
|
||||
info_by_id = {}
|
||||
for stored_metadata, saas_metadata in rows:
|
||||
info_by_id[stored_metadata.conversation_id] = (
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
)
|
||||
|
||||
results: list[AppConversationInfo | None] = []
|
||||
for conversation_id in conversation_id_strs:
|
||||
if conversation_id in info_by_id:
|
||||
stored_metadata, saas_metadata = info_by_id[conversation_id]
|
||||
results.append(
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
)
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return results
|
||||
|
||||
async def save_app_conversation_info(
|
||||
self, info: AppConversationInfo
|
||||
) -> AppConversationInfo:
|
||||
"""Save conversation info and create/update SAAS metadata with user_id and org_id."""
|
||||
# Save the base conversation metadata
|
||||
await super().save_app_conversation_info(info)
|
||||
|
||||
# Get current user_id for SAAS metadata
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
# Convert string user_id to UUID
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
user_query = select(User).where(User.id == user_id_uuid)
|
||||
result = await self.db_session.execute(user_query)
|
||||
user = result.scalar_one_or_none()
|
||||
assert user
|
||||
|
||||
# Check if SAAS metadata already exists
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(info.id)
|
||||
)
|
||||
result = await self.db_session.execute(saas_query)
|
||||
existing_saas_metadata = result.scalar_one_or_none()
|
||||
assert existing_saas_metadata is None or (
|
||||
existing_saas_metadata.user_id == user_id_uuid
|
||||
and existing_saas_metadata.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
if not existing_saas_metadata:
|
||||
# Create new SAAS metadata
|
||||
# Set org_id to user_id as specified in requirements
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(info.id),
|
||||
user_id=user_id_uuid,
|
||||
org_id=user.current_org_id,
|
||||
)
|
||||
self.db_session.add(saas_metadata)
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
return info
|
||||
|
||||
def _to_info_with_user_id(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas,
|
||||
) -> AppConversationInfo:
|
||||
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
|
||||
# Use the base _to_info method to get the basic info
|
||||
info = self._to_info(stored)
|
||||
|
||||
# Override the created_by_user_id with the user_id from SAAS metadata
|
||||
info.created_by_user_id = (
|
||||
str(saas_metadata.user_id) if saas_metadata.user_id else None
|
||||
)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
class SaasAppConversationInfoServiceInjector(AppConversationInfoServiceInjector):
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[AppConversationInfoService, None]:
|
||||
from openhands.app_server.config import (
|
||||
get_db_session,
|
||||
get_user_context,
|
||||
)
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=db_session, user_context=user_context
|
||||
)
|
||||
yield service
|
||||
@@ -1,85 +0,0 @@
|
||||
from storage.api_key import ApiKey
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.conversation_callback import CallbackStatus, ConversationCallback
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.experiment_assignment import ExperimentAssignment
|
||||
from storage.feedback import ConversationFeedback, Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.role import Role
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_team import SlackTeam
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.subscription_access_status import SubscriptionAccessStatus
|
||||
from storage.user import User
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
__all__ = [
|
||||
'ApiKey',
|
||||
'AuthTokens',
|
||||
'BillingSession',
|
||||
'BillingSessionType',
|
||||
'CallbackStatus',
|
||||
'ConversationCallback',
|
||||
'ConversationFeedback',
|
||||
'StoredConversationMetadataSaas',
|
||||
'ConversationWork',
|
||||
'ExperimentAssignment',
|
||||
'Feedback',
|
||||
'GithubAppInstallation',
|
||||
'GitlabWebhook',
|
||||
'JiraConversation',
|
||||
'JiraDcConversation',
|
||||
'JiraDcUser',
|
||||
'JiraDcWorkspace',
|
||||
'JiraUser',
|
||||
'JiraWorkspace',
|
||||
'LinearConversation',
|
||||
'LinearUser',
|
||||
'LinearWorkspace',
|
||||
'MaintenanceTask',
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'Role',
|
||||
'SlackConversation',
|
||||
'SlackTeam',
|
||||
'SlackUser',
|
||||
'StoredConversationMetadata',
|
||||
'StoredOfflineToken',
|
||||
'StoredRepository',
|
||||
'StoredCustomSecrets',
|
||||
'StripeCustomer',
|
||||
'SubscriptionAccess',
|
||||
'SubscriptionAccessStatus',
|
||||
'User',
|
||||
'UserRepositoryMap',
|
||||
'UserSettings',
|
||||
'WebhookStatus',
|
||||
]
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -13,13 +11,9 @@ class ApiKey(Base):
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
key = Column(String(255), nullable=False, unique=True, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
name = Column(String(255), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
)
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='api_keys')
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.api_key import ApiKey
|
||||
from storage.database import session_maker
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -40,15 +39,10 @@ class ApiKeyStore:
|
||||
The generated API key
|
||||
"""
|
||||
api_key = self.generate_api_key()
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
|
||||
with self.session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
key=api_key, user_id=user_id, name=name, expires_at=expires_at
|
||||
)
|
||||
session.add(key_record)
|
||||
session.commit()
|
||||
@@ -114,15 +108,8 @@ class ApiKeyStore:
|
||||
|
||||
def list_api_keys(self, user_id: str) -> list[dict]:
|
||||
"""List all API keys for a user."""
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
|
||||
return [
|
||||
{
|
||||
@@ -137,14 +124,9 @@ class ApiKeyStore:
|
||||
]
|
||||
|
||||
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
)
|
||||
for key in keys:
|
||||
if key.name == 'MCP_API_KEY':
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -13,9 +11,9 @@ class BillingSession(Base): # type: ignore
|
||||
"""
|
||||
|
||||
__tablename__ = 'billing_sessions'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
status = Column(
|
||||
Enum(
|
||||
'in_progress',
|
||||
@@ -26,6 +24,15 @@ class BillingSession(Base): # type: ignore
|
||||
),
|
||||
default='in_progress',
|
||||
)
|
||||
billing_session_type = Column(
|
||||
Enum(
|
||||
'DIRECT_PAYMENT',
|
||||
'MONTHLY_SUBSCRIPTION',
|
||||
name='billing_session_type_enum',
|
||||
),
|
||||
nullable=False,
|
||||
default='DIRECT_PAYMENT',
|
||||
)
|
||||
price = Column(DECIMAL(19, 4), nullable=False)
|
||||
price_code = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -36,6 +43,3 @@ class BillingSession(Base): # type: ignore
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC), # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='billing_sessions')
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -15,6 +10,7 @@ from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text, text
|
||||
from sqlalchemy import Enum as SQLEnum
|
||||
from storage.base import Base
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
@@ -37,7 +33,7 @@ class ConversationCallbackProcessor(BaseModel, ABC):
|
||||
async def __call__(
|
||||
self,
|
||||
callback: ConversationCallback,
|
||||
observation: 'AgentStateChangedObservation',
|
||||
observation: AgentStateChangedObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event.
|
||||
|
||||
@@ -1,38 +1,122 @@
|
||||
"""
|
||||
Database connection module for enterprise storage.
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
This is for backwards compatibility with V0.
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.util import await_only
|
||||
|
||||
This module provides database engines and session makers by delegating to the
|
||||
centralized DbSessionInjector from app_server/config.py. This ensures a single
|
||||
source of truth for database connection configuration.
|
||||
"""
|
||||
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
|
||||
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
|
||||
DB_USER = os.environ.get('DB_USER', 'postgres')
|
||||
DB_PASS = os.environ.get('DB_PASS', 'postgres').strip()
|
||||
DB_NAME = os.environ.get('DB_NAME', 'openhands')
|
||||
|
||||
import contextlib
|
||||
GCP_DB_INSTANCE = os.environ.get('GCP_DB_INSTANCE') # for GCP environments
|
||||
GCP_PROJECT = os.environ.get('GCP_PROJECT')
|
||||
GCP_REGION = os.environ.get('GCP_REGION')
|
||||
|
||||
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
|
||||
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
|
||||
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
|
||||
|
||||
# Initialize Cloud SQL Connector once at module level for GCP environments.
|
||||
_connector = None
|
||||
|
||||
|
||||
def _get_db_session_injector():
|
||||
from openhands.app_server.config import get_global_config
|
||||
def _get_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
|
||||
_config = get_global_config()
|
||||
return _config.db_session
|
||||
def get_db_connection():
|
||||
global _connector
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
if not _connector:
|
||||
_connector = Connector()
|
||||
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
|
||||
return _connector.connect(
|
||||
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
|
||||
)
|
||||
|
||||
return create_engine(
|
||||
'postgresql+pg8000://',
|
||||
creator=get_db_connection,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
else:
|
||||
host_string = (
|
||||
f'postgresql+pg8000://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
return create_engine(
|
||||
host_string,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
|
||||
def session_maker():
|
||||
db_session_injector = _get_db_session_injector()
|
||||
session_maker = db_session_injector.get_session_maker()
|
||||
return session_maker()
|
||||
async def async_creator():
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
async with Connector(loop=loop) as connector:
|
||||
conn = await connector.connect_async(
|
||||
f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}', # Cloud SQL instance connection name"
|
||||
'asyncpg',
|
||||
user=DB_USER,
|
||||
password=DB_PASS,
|
||||
db=DB_NAME,
|
||||
)
|
||||
return conn
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def a_session_maker():
|
||||
db_session_injector = _get_db_session_injector()
|
||||
a_session_maker = await db_session_injector.get_async_session_maker()
|
||||
async with a_session_maker() as session:
|
||||
yield session
|
||||
def _get_async_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
|
||||
def adapted_creator():
|
||||
dbapi = engine.dialect.dbapi
|
||||
from sqlalchemy.dialects.postgresql.asyncpg import (
|
||||
AsyncAdapt_asyncpg_connection,
|
||||
)
|
||||
|
||||
return AsyncAdapt_asyncpg_connection(
|
||||
dbapi,
|
||||
await_only(async_creator()),
|
||||
prepared_statement_cache_size=100,
|
||||
)
|
||||
|
||||
# create async connection pool with wrapped creator
|
||||
return create_async_engine(
|
||||
'postgresql+asyncpg://',
|
||||
creator=adapted_creator,
|
||||
# Use NullPool to disable connection pooling and avoid event loop issues
|
||||
poolclass=NullPool,
|
||||
)
|
||||
else:
|
||||
host_string = (
|
||||
f'postgresql+asyncpg://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
return create_async_engine(
|
||||
host_string,
|
||||
# Use NullPool to disable connection pooling and avoid event loop issues
|
||||
poolclass=NullPool,
|
||||
)
|
||||
|
||||
|
||||
def get_engine():
|
||||
db_session_injector = _get_db_session_injector()
|
||||
engine = db_session_injector.get_db_engine()
|
||||
return engine
|
||||
engine = _get_db_engine()
|
||||
session_maker = sessionmaker(bind=engine)
|
||||
|
||||
a_engine = _get_async_db_engine()
|
||||
a_session_maker = sessionmaker(
|
||||
bind=a_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
# Configure the session to use the same connection for all operations in a transaction
|
||||
# This helps prevent the "Task got Future attached to a different loop" error
|
||||
future=True,
|
||||
)
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import binascii
|
||||
import hashlib
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from pydantic import SecretStr
|
||||
from server.config import get_config
|
||||
|
||||
_jwt_service = None
|
||||
_fernet = None
|
||||
|
||||
|
||||
def encrypt_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if isinstance(value, dict):
|
||||
encrypt_kwargs(encrypt_keys, value)
|
||||
continue
|
||||
|
||||
if key in encrypt_keys:
|
||||
value = encrypt_value(value)
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_value(value: str | SecretStr) -> str:
|
||||
return get_jwt_service().create_jwe_token(
|
||||
{'v': value.get_secret_value() if isinstance(value, SecretStr) else value}
|
||||
)
|
||||
|
||||
|
||||
def decrypt_value(value: str | SecretStr) -> str:
|
||||
token = get_jwt_service().decrypt_jwe_token(
|
||||
value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
)
|
||||
return token['v']
|
||||
|
||||
|
||||
def get_jwt_service():
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
global _jwt_service
|
||||
if _jwt_service is None:
|
||||
jwt_service_injector = get_global_config().jwt
|
||||
assert jwt_service_injector is not None
|
||||
_jwt_service = jwt_service_injector.get_jwt_service()
|
||||
return _jwt_service
|
||||
|
||||
|
||||
def decrypt_legacy_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_legacy_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_legacy_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
except InvalidToken:
|
||||
pass # Key not encrypted...
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return (
|
||||
get_fernet().decrypt(b64decode(value.get_secret_value().encode())).decode()
|
||||
)
|
||||
else:
|
||||
return get_fernet().decrypt(b64decode(value.encode())).decode()
|
||||
|
||||
|
||||
def encrypt_legacy_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_legacy_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = encrypt_legacy_value(value)
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return b64encode(
|
||||
get_fernet().encrypt(value.get_secret_value().encode())
|
||||
).decode()
|
||||
else:
|
||||
return b64encode(get_fernet().encrypt(value.encode())).decode()
|
||||
|
||||
|
||||
def get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
jwt_secret = get_config().jwt_secret.get_secret_value()
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
_fernet = Fernet(fernet_key)
|
||||
return _fernet
|
||||
|
||||
|
||||
def model_to_kwargs(model_instance):
|
||||
return {
|
||||
column.name: getattr(model_instance, column.name)
|
||||
for column in model_instance.__table__.columns
|
||||
}
|
||||
@@ -1,16 +1,7 @@
|
||||
import sys
|
||||
from enum import IntEnum
|
||||
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy import ARRAY, Boolean, Column, DateTime, Integer, String, Text, text
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
|
||||
@@ -118,7 +118,9 @@ class JiraIntegrationStore:
|
||||
.first()
|
||||
)
|
||||
|
||||
async def get_workspace_by_name(self, workspace_name: str) -> JiraWorkspace | None:
|
||||
async def get_workspace_by_name(
|
||||
self, workspace_name: str
|
||||
) -> Optional[JiraWorkspace]:
|
||||
"""Retrieve workspace by name."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,100 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import SecretStr
|
||||
from server.constants import DEFAULT_BILLING_MARGIN
|
||||
from sqlalchemy import JSON, UUID, Boolean, Column, Float, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class Org(Base): # type: ignore
|
||||
"""Organization model."""
|
||||
|
||||
__tablename__ = 'org'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
contact_name = Column(String, nullable=True)
|
||||
contact_email = Column(String, nullable=True)
|
||||
agent = Column(String, nullable=True)
|
||||
default_max_iterations = Column(Integer, nullable=True)
|
||||
security_analyzer = Column(String, nullable=True)
|
||||
confirmation_mode = Column(Boolean, nullable=True, default=False)
|
||||
default_llm_model = Column(String, nullable=True)
|
||||
default_llm_base_url = Column(String, nullable=True)
|
||||
remote_runtime_resource_factor = Column(Integer, nullable=True)
|
||||
enable_default_condenser = Column(Boolean, nullable=False, default=True)
|
||||
billing_margin = Column(Float, nullable=True, default=DEFAULT_BILLING_MARGIN)
|
||||
enable_proactive_conversation_starters = Column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
sandbox_base_container_image = Column(String, nullable=True)
|
||||
sandbox_runtime_container_image = Column(String, nullable=True)
|
||||
org_version = Column(Integer, nullable=False, default=0)
|
||||
mcp_config = Column(JSON, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_search_api_key = Column(String, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_sandbox_api_key = Column(String, nullable=True)
|
||||
max_budget_per_task = Column(Float, nullable=True)
|
||||
enable_solvability_analysis = Column(Boolean, nullable=True, default=False)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
conversation_expiration = Column(Integer, nullable=True)
|
||||
condenser_max_size = Column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
)
|
||||
user_secrets = relationship('StoredCustomSecrets', back_populates='org')
|
||||
api_keys = relationship('ApiKey', back_populates='org')
|
||||
slack_conversations = relationship('SlackConversation', back_populates='org')
|
||||
slack_users = relationship('SlackUser', back_populates='org')
|
||||
stripe_customers = relationship('StripeCustomer', back_populates='org')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'search_api_key' in kwargs:
|
||||
self.search_api_key = kwargs.pop('search_api_key')
|
||||
if 'sandbox_api_key' in kwargs:
|
||||
self.sandbox_api_key = kwargs.pop('sandbox_api_key')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def search_api_key(self) -> SecretStr | None:
|
||||
if self._search_api_key:
|
||||
decrypted = decrypt_value(self._search_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@search_api_key.setter
|
||||
def search_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._search_api_key = encrypt_value(raw) if raw else None
|
||||
|
||||
@property
|
||||
def sandbox_api_key(self) -> SecretStr | None:
|
||||
if self._sandbox_api_key:
|
||||
decrypted = decrypt_value(self._sandbox_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@sandbox_api_key.setter
|
||||
def sandbox_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._sandbox_api_key = encrypt_value(raw) if raw else None
|
||||
@@ -1,67 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization-Member relationship.
|
||||
"""
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class OrgMember(Base): # type: ignore
|
||||
"""Junction table for organization-member relationships with roles."""
|
||||
|
||||
__tablename__ = 'org_member'
|
||||
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), primary_key=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), primary_key=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
_llm_api_key = Column(String, nullable=False)
|
||||
max_iterations = Column(Integer, nullable=True)
|
||||
llm_model = Column(String, nullable=True)
|
||||
_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
status = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='org_members')
|
||||
user = relationship('User', back_populates='org_members')
|
||||
role = relationship('Role', back_populates='org_members')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'llm_api_key' in kwargs:
|
||||
self.llm_api_key = kwargs.pop('llm_api_key')
|
||||
if 'llm_api_key_for_byor' in kwargs:
|
||||
self.llm_api_key_for_byor = kwargs.pop('llm_api_key_for_byor')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def llm_api_key(self) -> SecretStr:
|
||||
decrypted = decrypt_value(self._llm_api_key)
|
||||
return SecretStr(decrypted)
|
||||
|
||||
@llm_api_key.setter
|
||||
def llm_api_key(self, value: str | SecretStr):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key = encrypt_value(raw)
|
||||
|
||||
@property
|
||||
def llm_api_key_for_byor(self) -> SecretStr | None:
|
||||
if self._llm_api_key_for_byor:
|
||||
decrypted = decrypt_value(self._llm_api_key_for_byor)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@llm_api_key_for_byor.setter
|
||||
def llm_api_key_for_byor(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key_for_byor = encrypt_value(raw) if raw else None
|
||||
@@ -1,125 +0,0 @@
|
||||
"""
|
||||
Store class for managing organization-member relationships.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgMemberStore:
|
||||
"""Store for managing organization-member relationships."""
|
||||
|
||||
@staticmethod
|
||||
def add_user_to_org(
|
||||
org_id: UUID,
|
||||
user_id: UUID,
|
||||
role_id: int,
|
||||
llm_api_key: str,
|
||||
status: Optional[str] = None,
|
||||
) -> OrgMember:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
with session_maker() as session:
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key=llm_api_key,
|
||||
status=status,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: int) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
|
||||
|
||||
@staticmethod
|
||||
def get_org_members(org_id: UUID) -> list[OrgMember]:
|
||||
"""Get all users in an organization."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.org_id == org_id).all()
|
||||
|
||||
@staticmethod
|
||||
def update_org_member(org_member: OrgMember) -> None:
|
||||
"""Update an organization-member relationship."""
|
||||
with session_maker() as session:
|
||||
session.merge(org_member)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_user_role_in_org(
|
||||
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
|
||||
) -> Optional[OrgMember]:
|
||||
"""Update user's role in an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
org_member.role_id = role_id
|
||||
if status is not None:
|
||||
org_member.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
|
||||
"""Remove a user from an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return False
|
||||
|
||||
session.delete(org_member)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
@@ -1,844 +0,0 @@
|
||||
"""
|
||||
Service class for managing organization operations.
|
||||
Separates business logic from route handlers.
|
||||
"""
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgAuthorizationError,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrgUpdate,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class OrgService:
|
||||
"""Service for handling organization-related operations."""
|
||||
|
||||
@staticmethod
|
||||
def validate_name_uniqueness(name: str) -> None:
|
||||
"""
|
||||
Validate that organization name is unique.
|
||||
|
||||
Args:
|
||||
name: Organization name to validate
|
||||
|
||||
Raises:
|
||||
OrgNameExistsError: If organization name already exists
|
||||
"""
|
||||
existing_org = OrgStore.get_org_by_name(name)
|
||||
if existing_org is not None:
|
||||
raise OrgNameExistsError(name)
|
||||
|
||||
@staticmethod
|
||||
async def create_litellm_integration(org_id: UUID, user_id: str) -> dict:
|
||||
"""
|
||||
Create LiteLLM team integration for the organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID who will own the organization
|
||||
|
||||
Returns:
|
||||
dict: LiteLLM settings object
|
||||
|
||||
Raises:
|
||||
LiteLLMIntegrationError: If LiteLLM integration fails
|
||||
"""
|
||||
try:
|
||||
settings = await UserStore.create_default_settings(
|
||||
org_id=str(org_id), user_id=user_id, create_user=False
|
||||
)
|
||||
|
||||
if not settings:
|
||||
logger.error(
|
||||
'Failed to create LiteLLM settings',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id},
|
||||
)
|
||||
raise LiteLLMIntegrationError('Failed to create LiteLLM settings')
|
||||
|
||||
logger.debug(
|
||||
'LiteLLM integration created',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id},
|
||||
)
|
||||
return settings
|
||||
|
||||
except LiteLLMIntegrationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error creating LiteLLM integration',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise LiteLLMIntegrationError(f'LiteLLM integration failed: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
def create_org_entity(
|
||||
org_id: UUID,
|
||||
name: str,
|
||||
contact_name: str,
|
||||
contact_email: str,
|
||||
) -> Org:
|
||||
"""
|
||||
Create an organization entity with basic information.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
name: Organization name
|
||||
contact_name: Contact person name
|
||||
contact_email: Contact email address
|
||||
|
||||
Returns:
|
||||
Org: New organization entity (not yet persisted)
|
||||
"""
|
||||
return Org(
|
||||
id=org_id,
|
||||
name=name,
|
||||
contact_name=contact_name,
|
||||
contact_email=contact_email,
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
default_llm_model=get_default_litellm_model(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def apply_litellm_settings_to_org(org: Org, settings: dict) -> None:
|
||||
"""
|
||||
Apply LiteLLM settings to organization entity.
|
||||
|
||||
Args:
|
||||
org: Organization entity to update
|
||||
settings: LiteLLM settings object
|
||||
"""
|
||||
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
@staticmethod
|
||||
def get_owner_role():
|
||||
"""
|
||||
Get the owner role from the database.
|
||||
|
||||
Returns:
|
||||
Role: The owner role object
|
||||
|
||||
Raises:
|
||||
Exception: If owner role not found
|
||||
"""
|
||||
owner_role = RoleStore.get_role_by_name('owner')
|
||||
if not owner_role:
|
||||
raise Exception('Owner role not found in database')
|
||||
return owner_role
|
||||
|
||||
@staticmethod
|
||||
def create_org_member_entity(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
role_id: int,
|
||||
settings: dict,
|
||||
) -> OrgMember:
|
||||
"""
|
||||
Create an organization member entity.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
role_id: Role ID
|
||||
settings: LiteLLM settings object
|
||||
|
||||
Returns:
|
||||
OrgMember: New organization member entity (not yet persisted)
|
||||
"""
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
|
||||
return OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=parse_uuid(user_id),
|
||||
role_id=role_id,
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_org_with_owner(
|
||||
name: str,
|
||||
contact_name: str,
|
||||
contact_email: str,
|
||||
user_id: str,
|
||||
) -> Org:
|
||||
"""
|
||||
Create a new organization with the specified user as owner.
|
||||
|
||||
This method orchestrates the complete organization creation workflow:
|
||||
1. Validates that the organization name doesn't already exist
|
||||
2. Generates a unique organization ID
|
||||
3. Creates LiteLLM team integration
|
||||
4. Creates the organization entity
|
||||
5. Applies LiteLLM settings
|
||||
6. Creates owner membership
|
||||
7. Persists everything in a transaction
|
||||
|
||||
If database persistence fails, LiteLLM resources are cleaned up (compensation).
|
||||
|
||||
Args:
|
||||
name: Organization name (must be unique)
|
||||
contact_name: Contact person name
|
||||
contact_email: Contact email address
|
||||
user_id: ID of the user who will be the owner
|
||||
|
||||
Returns:
|
||||
Org: The created organization object
|
||||
|
||||
Raises:
|
||||
OrgNameExistsError: If organization name already exists
|
||||
LiteLLMIntegrationError: If LiteLLM integration fails
|
||||
OrgDatabaseError: If database operations fail
|
||||
"""
|
||||
logger.info(
|
||||
'Starting organization creation',
|
||||
extra={'user_id': user_id, 'org_name': name},
|
||||
)
|
||||
|
||||
# Step 1: Validate name uniqueness (fails early, no cleanup needed)
|
||||
OrgService.validate_name_uniqueness(name)
|
||||
|
||||
# Step 2: Generate organization ID
|
||||
org_id = uuid4()
|
||||
|
||||
# Step 3: Create LiteLLM integration (external state created)
|
||||
settings = await OrgService.create_litellm_integration(org_id, user_id)
|
||||
|
||||
# Steps 4-7: Create entities and persist with compensation
|
||||
# If any of these fail, we need to clean up LiteLLM resources
|
||||
try:
|
||||
# Step 4: Create organization entity
|
||||
org = OrgService.create_org_entity(
|
||||
org_id=org_id,
|
||||
name=name,
|
||||
contact_name=contact_name,
|
||||
contact_email=contact_email,
|
||||
)
|
||||
|
||||
# Step 5: Apply LiteLLM settings
|
||||
OrgService.apply_litellm_settings_to_org(org, settings)
|
||||
|
||||
# Step 6: Get owner role and create member entity
|
||||
owner_role = OrgService.get_owner_role()
|
||||
org_member = OrgService.create_org_member_entity(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=owner_role.id,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
# Step 7: Persist in transaction (critical section)
|
||||
persisted_org = await OrgService._persist_with_compensation(
|
||||
org, org_member, org_id, user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Successfully created organization',
|
||||
extra={
|
||||
'org_id': str(persisted_org.id),
|
||||
'org_name': persisted_org.name,
|
||||
'user_id': user_id,
|
||||
'role': 'owner',
|
||||
},
|
||||
)
|
||||
|
||||
return persisted_org
|
||||
|
||||
except OrgDatabaseError:
|
||||
# Already handled by _persist_with_compensation, just re-raise
|
||||
raise
|
||||
except Exception as e:
|
||||
# Unexpected error in steps 4-6, need to clean up LiteLLM
|
||||
logger.error(
|
||||
'Unexpected error during organization creation, initiating cleanup',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
await OrgService._handle_failure_with_cleanup(
|
||||
org_id, user_id, e, 'Failed to create organization'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _persist_with_compensation(
|
||||
org: Org,
|
||||
org_member: OrgMember,
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
) -> Org:
|
||||
"""
|
||||
Persist organization with compensation on failure.
|
||||
|
||||
If database persistence fails, cleans up LiteLLM resources.
|
||||
|
||||
Args:
|
||||
org: Organization entity to persist
|
||||
org_member: Organization member entity to persist
|
||||
org_id: Organization ID (for cleanup)
|
||||
user_id: User ID (for cleanup)
|
||||
|
||||
Returns:
|
||||
Org: The persisted organization object
|
||||
|
||||
Raises:
|
||||
OrgDatabaseError: If database operations fail
|
||||
"""
|
||||
try:
|
||||
persisted_org = OrgStore.persist_org_with_owner(org, org_member)
|
||||
return persisted_org
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Database persistence failed, initiating LiteLLM cleanup',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
await OrgService._handle_failure_with_cleanup(
|
||||
org_id, user_id, e, 'Failed to create organization'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _handle_failure_with_cleanup(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
original_error: Exception,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
"""
|
||||
Handle failure by cleaning up LiteLLM resources and raising appropriate error.
|
||||
|
||||
This method performs compensating transaction and raises OrgDatabaseError.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID
|
||||
original_error: The original exception that caused the failure
|
||||
error_message: Base error message for the exception
|
||||
|
||||
Raises:
|
||||
OrgDatabaseError: Always raises with details about the failure
|
||||
"""
|
||||
cleanup_error = await OrgService._cleanup_litellm_resources(org_id, user_id)
|
||||
|
||||
if cleanup_error:
|
||||
logger.error(
|
||||
'Both operation and cleanup failed',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'original_error': str(original_error),
|
||||
'cleanup_error': str(cleanup_error),
|
||||
},
|
||||
)
|
||||
raise OrgDatabaseError(
|
||||
f'{error_message}: {str(original_error)}. '
|
||||
f'Cleanup also failed: {str(cleanup_error)}'
|
||||
)
|
||||
|
||||
raise OrgDatabaseError(f'{error_message}: {str(original_error)}')
|
||||
|
||||
@staticmethod
|
||||
async def _cleanup_litellm_resources(
|
||||
org_id: UUID, user_id: str
|
||||
) -> Exception | None:
|
||||
"""
|
||||
Compensating transaction: Clean up LiteLLM resources.
|
||||
|
||||
Deletes the team which should cascade to remove keys and memberships.
|
||||
This is a best-effort operation - errors are logged but not raised.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Exception | None: Exception if cleanup failed, None if successful
|
||||
"""
|
||||
try:
|
||||
await LiteLlmManager.delete_team(str(org_id))
|
||||
|
||||
logger.info(
|
||||
'Successfully cleaned up LiteLLM team',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to cleanup LiteLLM team (resources may be orphaned)',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return e
|
||||
|
||||
@staticmethod
|
||||
def has_admin_or_owner_role(user_id: str, org_id: UUID) -> bool:
|
||||
"""
|
||||
Check if user has admin or owner role in the specified organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
org_id: Organization ID to check membership in
|
||||
|
||||
Returns:
|
||||
bool: True if user has admin or owner role, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Parse user_id as UUID for database query
|
||||
user_uuid = parse_uuid(user_id)
|
||||
|
||||
# Get the user's membership in this organization
|
||||
# Note: The type annotation says int but the actual column is UUID
|
||||
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
|
||||
if not org_member:
|
||||
return False
|
||||
|
||||
# Get the role details
|
||||
role = RoleStore.get_role_by_id(org_member.role_id)
|
||||
if not role:
|
||||
return False
|
||||
|
||||
# Admin and owner roles have elevated permissions
|
||||
# Based on test files, both admin and owner have rank 1
|
||||
return role.name in ['admin', 'owner']
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'Error checking user role in organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_org_member(user_id: str, org_id: UUID) -> bool:
|
||||
"""
|
||||
Check if user is a member of the specified organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
org_id: Organization ID to check membership in
|
||||
|
||||
Returns:
|
||||
bool: True if user is a member, False otherwise
|
||||
"""
|
||||
try:
|
||||
user_uuid = parse_uuid(user_id)
|
||||
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
|
||||
return org_member is not None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'Error checking user membership in organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_llm_settings_fields() -> set[str]:
|
||||
"""
|
||||
Get the set of organization fields that are considered LLM settings
|
||||
and require admin/owner role to update.
|
||||
|
||||
Returns:
|
||||
set[str]: Set of field names that require elevated permissions
|
||||
"""
|
||||
return {
|
||||
'default_llm_model',
|
||||
'default_llm_api_key_for_byor',
|
||||
'default_llm_base_url',
|
||||
'search_api_key',
|
||||
'security_analyzer',
|
||||
'agent',
|
||||
'confirmation_mode',
|
||||
'enable_default_condenser',
|
||||
'condenser_max_size',
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_llm_settings_updates(update_data: OrgUpdate) -> set[str]:
|
||||
"""
|
||||
Check if the update contains any LLM settings fields.
|
||||
|
||||
Args:
|
||||
update_data: The organization update data
|
||||
|
||||
Returns:
|
||||
set[str]: Set of LLM fields being updated (empty if none)
|
||||
"""
|
||||
llm_fields = OrgService._get_llm_settings_fields()
|
||||
update_dict = update_data.model_dump(exclude_none=True)
|
||||
return llm_fields.intersection(update_dict.keys())
|
||||
|
||||
@staticmethod
|
||||
async def update_org_with_permissions(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str,
|
||||
) -> Org:
|
||||
"""
|
||||
Update organization with permission checks for LLM settings.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID to update
|
||||
update_data: Organization update data from request
|
||||
user_id: ID of the user requesting the update
|
||||
|
||||
Returns:
|
||||
Org: The updated organization object
|
||||
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Updating organization with permission checks',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'has_update_data': update_data is not None,
|
||||
},
|
||||
)
|
||||
|
||||
# Validate organization exists
|
||||
existing_org = OrgStore.get_org_by_id(org_id)
|
||||
if not existing_org:
|
||||
raise ValueError(f'Organization with ID {org_id} not found')
|
||||
|
||||
# Check if user is a member of this organization
|
||||
if not OrgService.is_org_member(user_id, org_id):
|
||||
logger.warning(
|
||||
'Non-member attempted to update organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise PermissionError(
|
||||
'User must be a member of the organization to update it'
|
||||
)
|
||||
|
||||
# Check if update contains any LLM settings
|
||||
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
|
||||
if llm_fields_being_updated:
|
||||
# Verify user has admin or owner role
|
||||
has_permission = OrgService.has_admin_or_owner_role(user_id, org_id)
|
||||
if not has_permission:
|
||||
logger.warning(
|
||||
'User attempted to update LLM settings without permission',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'attempted_fields': list(llm_fields_being_updated),
|
||||
},
|
||||
)
|
||||
raise PermissionError(
|
||||
'Admin or owner role required to update LLM settings'
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'User has permission to update LLM settings',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'llm_fields': list(llm_fields_being_updated),
|
||||
},
|
||||
)
|
||||
|
||||
# Convert to dict for OrgStore (excluding None values)
|
||||
update_dict = update_data.model_dump(exclude_none=True)
|
||||
if not update_dict:
|
||||
logger.info(
|
||||
'No fields to update',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id},
|
||||
)
|
||||
return existing_org
|
||||
|
||||
# Perform the update
|
||||
try:
|
||||
updated_org = OrgStore.update_org(org_id, update_dict)
|
||||
if not updated_org:
|
||||
raise OrgDatabaseError('Failed to update organization in database')
|
||||
|
||||
logger.info(
|
||||
'Organization updated successfully',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'updated_fields': list(update_dict.keys()),
|
||||
},
|
||||
)
|
||||
|
||||
return updated_org
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to update organization',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to update organization: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
async def get_org_credits(user_id: str, org_id: UUID) -> float | None:
|
||||
"""
|
||||
Get organization credits from LiteLLM team.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
org_id: Organization ID
|
||||
|
||||
Returns:
|
||||
float | None: Credits (max_budget - spend) or None if LiteLLM not configured
|
||||
"""
|
||||
try:
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(org_id)
|
||||
)
|
||||
if not user_team_info:
|
||||
logger.warning(
|
||||
'No team info available from LiteLLM',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
return None
|
||||
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
spend = user_team_info.get('spend', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
|
||||
logger.debug(
|
||||
'Retrieved organization credits',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'credits': credits,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
|
||||
return credits
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'Failed to retrieve organization credits',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
user_id: str, page_id: str | None = None, limit: int = 100
|
||||
):
|
||||
"""
|
||||
Get paginated list of organizations for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
page_id: Optional page ID (offset as string) for pagination
|
||||
limit: Maximum number of organizations to return
|
||||
|
||||
Returns:
|
||||
Tuple of (list of Org objects, next_page_id or None)
|
||||
"""
|
||||
logger.debug(
|
||||
'Fetching paginated organizations for user',
|
||||
extra={'user_id': user_id, 'page_id': page_id, 'limit': limit},
|
||||
)
|
||||
|
||||
# Convert user_id string to UUID
|
||||
user_uuid = parse_uuid(user_id)
|
||||
|
||||
# Fetch organizations from store
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_uuid, page_id=page_id, limit=limit
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'Retrieved organizations for user',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_count': len(orgs),
|
||||
'has_more': next_page_id is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return orgs, next_page_id
|
||||
|
||||
@staticmethod
|
||||
async def get_org_by_id(org_id: UUID, user_id: str) -> Org:
|
||||
"""
|
||||
Get organization by ID with membership validation.
|
||||
|
||||
This method verifies that the user is a member of the organization
|
||||
before returning the organization details.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
|
||||
Returns:
|
||||
Org: The organization object
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization not found or user is not a member
|
||||
"""
|
||||
logger.info(
|
||||
'Retrieving organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Verify user is a member of the organization
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
logger.warning(
|
||||
'User is not a member of organization or organization does not exist',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Retrieve organization
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
'Organization not found despite valid membership',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organization',
|
||||
extra={
|
||||
'org_id': str(org.id),
|
||||
'org_name': org.name,
|
||||
'user_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def verify_owner_authorization(user_id: str, org_id: UUID) -> None:
|
||||
"""
|
||||
Verify that the user is the owner of the organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
org_id: Organization ID
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not authorized to delete
|
||||
"""
|
||||
# Check if organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Check if user is a member of the organization
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
raise OrgAuthorizationError('User is not a member of this organization')
|
||||
|
||||
# Check if user has owner role
|
||||
role = RoleStore.get_role_by_id(org_member.role_id)
|
||||
if not role or role.name != 'owner':
|
||||
raise OrgAuthorizationError(
|
||||
'Only organization owners can delete organizations'
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'User authorization verified for organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'role': role.name},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_org_with_cleanup(user_id: str, org_id: UUID) -> Org:
|
||||
"""
|
||||
Delete organization with complete cleanup of all associated data.
|
||||
|
||||
This method performs the complete organization deletion workflow:
|
||||
1. Verifies user authorization (owner only)
|
||||
2. Performs database cascade deletion and LiteLLM cleanup in single transaction
|
||||
|
||||
Args:
|
||||
user_id: User ID requesting deletion (must be owner)
|
||||
org_id: Organization ID to delete
|
||||
|
||||
Returns:
|
||||
Org: The deleted organization details
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not authorized to delete
|
||||
OrgDatabaseError: If database operations or LiteLLM cleanup fail
|
||||
"""
|
||||
logger.info(
|
||||
'Starting organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Step 1: Verify user authorization
|
||||
OrgService.verify_owner_authorization(user_id, org_id)
|
||||
|
||||
# Step 2: Perform database cascade deletion with LiteLLM cleanup in transaction
|
||||
try:
|
||||
deleted_org = await OrgStore.delete_org_cascade(org_id)
|
||||
if not deleted_org:
|
||||
# This shouldn't happen since we verified existence above
|
||||
raise OrgDatabaseError('Organization not found during deletion')
|
||||
|
||||
logger.info(
|
||||
'Organization deletion completed successfully',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': deleted_org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return deleted_org
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Organization deletion failed',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
|
||||
@@ -1,363 +0,0 @@
|
||||
"""
|
||||
Store class for managing organizations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgStore:
|
||||
"""Store for managing organizations."""
|
||||
|
||||
@staticmethod
|
||||
def create_org(
|
||||
kwargs: dict,
|
||||
) -> Org:
|
||||
"""Create a new organization."""
|
||||
with session_maker() as session:
|
||||
org = Org(**kwargs)
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
session.add(org)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_id(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID."""
|
||||
org = None
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
return OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None:
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == UUID(keycloak_user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(f'User not found for ID {keycloak_user_id}')
|
||||
return None
|
||||
org_id = user.current_org_id
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}'
|
||||
)
|
||||
return None
|
||||
return OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_name(name: str) -> Org | None:
|
||||
"""Get organization by name."""
|
||||
org = None
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.name == name).first()
|
||||
return OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def _validate_org_version(org: Org) -> Org | None:
|
||||
"""Check if we need to update org version."""
|
||||
if org and org.org_version < ORG_SETTINGS_VERSION:
|
||||
org = OrgStore.update_org(
|
||||
org.id,
|
||||
{
|
||||
'org_version': ORG_SETTINGS_VERSION,
|
||||
'default_llm_model': get_default_litellm_model(),
|
||||
'llm_base_url': LITE_LLM_API_URL,
|
||||
},
|
||||
)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def list_orgs() -> list[Org]:
|
||||
"""List all organizations."""
|
||||
with session_maker() as session:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
user_id: UUID, page_id: str | None = None, limit: int = 100
|
||||
) -> tuple[list[Org], str | None]:
|
||||
"""
|
||||
Get paginated list of organizations for a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
page_id: Optional page ID (offset as string) for pagination
|
||||
limit: Maximum number of organizations to return
|
||||
|
||||
Returns:
|
||||
Tuple of (list of Org objects, next_page_id or None)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Build query joining OrgMember with Org
|
||||
query = (
|
||||
session.query(Org)
|
||||
.join(OrgMember, Org.id == OrgMember.org_id)
|
||||
.filter(OrgMember.user_id == user_id)
|
||||
.order_by(Org.name)
|
||||
)
|
||||
|
||||
# Apply pagination offset
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Fetch limit + 1 to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
orgs = query.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(orgs) > limit
|
||||
if has_more:
|
||||
orgs = orgs[:limit]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
# Validate org versions
|
||||
validated_orgs = [
|
||||
OrgStore._validate_org_version(org) for org in orgs if org
|
||||
]
|
||||
validated_orgs = [org for org in validated_orgs if org is not None]
|
||||
|
||||
return validated_orgs, next_page_id
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
org_id: UUID,
|
||||
kwargs: dict,
|
||||
) -> Optional[Org]:
|
||||
"""Update organization details."""
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
if 'id' in kwargs:
|
||||
kwargs.pop('id')
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(settings, normalized)
|
||||
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(user_settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(user_settings, normalized)
|
||||
|
||||
kwargs['org_version'] = user_settings.user_version
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def persist_org_with_owner(
|
||||
org: Org,
|
||||
org_member: OrgMember,
|
||||
) -> Org:
|
||||
"""
|
||||
Persist organization and owner membership in a single transaction.
|
||||
|
||||
Args:
|
||||
org: Organization entity to persist
|
||||
org_member: Organization member entity to persist
|
||||
|
||||
Returns:
|
||||
Org: The persisted organization object
|
||||
|
||||
Raises:
|
||||
Exception: If database operations fail
|
||||
"""
|
||||
with session_maker() as session:
|
||||
session.add(org)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
async def delete_org_cascade(org_id: UUID) -> Org | None:
|
||||
"""
|
||||
Delete organization and all associated data in cascade, including external LiteLLM cleanup.
|
||||
|
||||
Args:
|
||||
org_id: UUID of the organization to delete
|
||||
|
||||
Returns:
|
||||
Org: The deleted organization object, or None if not found
|
||||
|
||||
Raises:
|
||||
Exception: If database operations or LiteLLM cleanup fail
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# First get the organization to return it
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. Delete conversation data for organization conversations
|
||||
session.execute(
|
||||
text("""
|
||||
DELETE FROM conversation_metadata
|
||||
WHERE conversation_id IN (
|
||||
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
session.execute(
|
||||
text("""
|
||||
DELETE FROM app_conversation_start_task
|
||||
WHERE app_conversation_id::text IN (
|
||||
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 2. Delete organization-owned data tables (direct org_id foreign keys)
|
||||
session.execute(
|
||||
text('DELETE FROM billing_sessions WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text(
|
||||
'DELETE FROM conversation_metadata_saas WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM custom_secrets WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM api_keys WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM slack_conversation WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM slack_users WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM stripe_customers WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 3. Delete organization memberships
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 4. Handle users with this as current_org_id
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
|
||||
),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 5. Finally delete the organization
|
||||
session.delete(org)
|
||||
|
||||
# 6. Clean up LiteLLM team before committing transaction
|
||||
logger.info(
|
||||
'Deleting LiteLLM team within database transaction',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
await LiteLlmManager.delete_team(str(org_id))
|
||||
|
||||
# 7. Commit all changes only if everything succeeded
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'Successfully deleted organization and all associated data including LiteLLM team',
|
||||
extra={'org_id': str(org_id), 'org_name': org.name},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
'Failed to delete organization - transaction rolled back',
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise
|
||||
@@ -1,21 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for Role.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class Role(Base): # type: ignore
|
||||
"""Role model for user permissions."""
|
||||
|
||||
__tablename__ = 'role'
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
rank = Column(Integer, nullable=False)
|
||||
|
||||
# Relationships
|
||||
users = relationship('User', back_populates='role')
|
||||
org_members = relationship('OrgMember', back_populates='role')
|
||||
@@ -1,56 +0,0 @@
|
||||
"""
|
||||
Store class for managing roles.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
class RoleStore:
|
||||
"""Store for managing roles."""
|
||||
|
||||
@staticmethod
|
||||
def create_role(name: str, rank: int) -> Role:
|
||||
"""Create a new role."""
|
||||
with session_maker() as session:
|
||||
role = Role(name=name, rank=rank)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
session.refresh(role)
|
||||
return role
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_id(role_id: int) -> Optional[Role]:
|
||||
"""Get role by ID."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
async def get_role_by_name_async(
|
||||
name: str,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
if session is not None:
|
||||
result = await session.execute(select(Role).where(Role.name == name))
|
||||
return result.scalars().first()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Role).where(Role.name == name))
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def list_roles() -> List[Role]:
|
||||
"""List all roles."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).order_by(Role.rank).all()
|
||||
@@ -4,13 +4,10 @@ import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
@@ -32,37 +29,20 @@ logger = logging.getLogger(__name__)
|
||||
class SaasConversationStore(ConversationStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(self, user_id: str, session_maker: sessionmaker):
|
||||
self.user_id = user_id
|
||||
self.session_maker = session_maker
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
self.org_id = user.current_org_id if user else None
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
|
||||
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
|
||||
query = (
|
||||
return (
|
||||
session.query(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.user_id == UUID(self.user_id))
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
)
|
||||
|
||||
if self.org_id is not None:
|
||||
query = query.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
|
||||
return query
|
||||
|
||||
def _to_external_model(self, conversation_metadata: StoredConversationMetadata):
|
||||
kwargs = {
|
||||
c.name: getattr(conversation_metadata, c.name)
|
||||
for c in StoredConversationMetadata.__table__.columns
|
||||
if c.name != 'github_user_id' # Skip github_user_id field
|
||||
}
|
||||
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
|
||||
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
|
||||
@@ -73,8 +53,6 @@ class SaasConversationStore(ConversationStore):
|
||||
# Convert string to ProviderType enum
|
||||
kwargs['git_provider'] = ProviderType(kwargs['git_provider'])
|
||||
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove V1 attributes
|
||||
kwargs.pop('max_budget_per_task', None)
|
||||
kwargs.pop('cache_read_tokens', None)
|
||||
@@ -89,10 +67,7 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
async def save_metadata(self, metadata: ConversationMetadata):
|
||||
kwargs = dataclasses.asdict(metadata)
|
||||
|
||||
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
|
||||
kwargs.pop('user_id', None)
|
||||
kwargs.pop('org_id', None)
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Convert ProviderType enum to string for storage
|
||||
if kwargs.get('git_provider') is not None:
|
||||
@@ -106,41 +81,7 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
def _save_metadata():
|
||||
with self.session_maker() as session:
|
||||
# Save the main conversation metadata
|
||||
session.merge(stored_metadata)
|
||||
|
||||
# Create or update the SaaS metadata record
|
||||
saas_metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== stored_metadata.conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not saas_metadata:
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=stored_metadata.conversation_id,
|
||||
user_id=UUID(self.user_id),
|
||||
org_id=self.org_id,
|
||||
)
|
||||
session.add(saas_metadata)
|
||||
else:
|
||||
# Validate
|
||||
expected_user_id = UUID(self.user_id)
|
||||
expected_org_id = self.org_id
|
||||
|
||||
if saas_metadata.user_id != expected_user_id:
|
||||
raise ValueError(
|
||||
f'Existing user_id ({saas_metadata.user_id}) does not match expected value ({expected_user_id}).'
|
||||
)
|
||||
|
||||
if expected_org_id and saas_metadata.org_id != expected_org_id:
|
||||
raise ValueError(
|
||||
f'Existing org_id ({saas_metadata.org_id}) does not match expected value ({expected_org_id}).'
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_metadata)
|
||||
@@ -160,29 +101,8 @@ class SaasConversationStore(ConversationStore):
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
def _delete_metadata():
|
||||
with self.session_maker() as session:
|
||||
saas_record = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id),
|
||||
StoredConversationMetadataSaas.org_id == self.org_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if saas_record:
|
||||
# Delete both records, but only if the SaaS one exists
|
||||
session.query(StoredConversationMetadata).filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
).delete()
|
||||
|
||||
session.delete(saas_record)
|
||||
|
||||
session.commit()
|
||||
else:
|
||||
# No SaaS record found → skip deleting main metadata
|
||||
session.rollback()
|
||||
self._select_by_id(session, conversation_id).delete()
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_delete_metadata)
|
||||
|
||||
@@ -205,15 +125,7 @@ class SaasConversationStore(ConversationStore):
|
||||
with self.session_maker() as session:
|
||||
conversations = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id)
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
.order_by(StoredConversationMetadata.created_at.desc())
|
||||
.offset(offset)
|
||||
|
||||
@@ -8,7 +8,6 @@ from cryptography.fernet import Fernet
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -25,17 +24,14 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def load(self) -> Secrets | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Fetch all secrets for the given user ID
|
||||
query = session.query(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
settings = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
)
|
||||
if org_id is not None:
|
||||
query = query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
settings = query.all()
|
||||
|
||||
if not settings:
|
||||
return Secrets()
|
||||
@@ -52,8 +48,6 @@ class SaasSecretsStore(SecretsStore):
|
||||
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def store(self, item: Secrets):
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
@@ -82,7 +76,6 @@ class SaasSecretsStore(SecretsStore):
|
||||
for secret_name, secret_value, description in secret_tuples:
|
||||
new_secret = StoredCustomSecrets(
|
||||
keycloak_user_id=self.user_id,
|
||||
org_id=org_id,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
description=description,
|
||||
|
||||
@@ -1,28 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import binascii
|
||||
import hashlib
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
from base64 import b64decode, b64encode
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
from cryptography.fernet import Fernet
|
||||
from integrations import stripe_service
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
REQUIRE_PAYMENT,
|
||||
USER_SETTINGS_VERSION_TO_MODEL,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# The max possible time to wait for another process to finish creating a user before retrying
|
||||
_REDIS_CREATE_TIMEOUT_SECONDS = 30
|
||||
# The delay to wait for another process to finish creating a user before trying to load again
|
||||
_RETRY_LOAD_DELAY_SECONDS = 2
|
||||
# Redis key prefix for user creation locks
|
||||
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,9 +48,8 @@ class SaasSettingsStore(SettingsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
|
||||
|
||||
def _get_user_settings_by_keycloak_id(
|
||||
def get_user_settings_by_keycloak_id(
|
||||
self, keycloak_user_id: str, session=None
|
||||
) -> UserSettings | None:
|
||||
"""
|
||||
@@ -68,105 +85,354 @@ class SaasSettingsStore(SettingsStore):
|
||||
return _get_settings()
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
if not user:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
if not self.user_id:
|
||||
return None
|
||||
with self.session_maker() as session:
|
||||
settings = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
org_id = user.current_org_id
|
||||
org_member: OrgMember = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
kwargs = {
|
||||
**{
|
||||
normalized: getattr(org, c.name)
|
||||
for c in Org.__table__.columns
|
||||
if (
|
||||
normalized := c.name.removeprefix('_default_')
|
||||
.removeprefix('default_')
|
||||
.lstrip('_')
|
||||
if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION:
|
||||
logger.info(
|
||||
'saas_settings_store:load:triggering_migration',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
in Settings.model_fields
|
||||
},
|
||||
**{
|
||||
normalized: getattr(user, c.name)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) in Settings.model_fields
|
||||
},
|
||||
}
|
||||
kwargs['llm_api_key'] = org_member.llm_api_key
|
||||
if org_member.max_iterations:
|
||||
kwargs['max_iterations'] = org_member.max_iterations
|
||||
if org_member.llm_model:
|
||||
kwargs['llm_model'] = org_member.llm_model
|
||||
if org_member.llm_api_key_for_byor:
|
||||
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
|
||||
if org_member.llm_base_url:
|
||||
kwargs['llm_base_url'] = org_member.llm_base_url
|
||||
if org.v1_enabled is None:
|
||||
kwargs['v1_enabled'] = True
|
||||
return await self.create_default_settings(settings)
|
||||
kwargs = {
|
||||
c.name: getattr(settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
with self.session_maker() as session:
|
||||
if not item:
|
||||
return None
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(self.user_id))
|
||||
).first()
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if item and self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item)
|
||||
|
||||
with self.session_maker() as session:
|
||||
existing = None
|
||||
kwargs = {}
|
||||
if item:
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
self._encrypt_kwargs(kwargs)
|
||||
# First check if we have an existing entry in the new table
|
||||
existing = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
kwargs = {
|
||||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if key in UserSettings.__table__.columns
|
||||
}
|
||||
if existing:
|
||||
# Update existing entry
|
||||
for key, value in kwargs.items():
|
||||
setattr(existing, key, value)
|
||||
existing.user_version = CURRENT_USER_SETTINGS_VERSION
|
||||
session.merge(existing)
|
||||
else:
|
||||
kwargs['keycloak_user_id'] = self.user_id
|
||||
kwargs['user_version'] = CURRENT_USER_SETTINGS_VERSION
|
||||
kwargs.pop('secrets_store', None) # Don't save secrets_store to db
|
||||
settings = UserSettings(**kwargs)
|
||||
session.add(settings)
|
||||
session.commit()
|
||||
|
||||
def _get_redis_client(self):
|
||||
"""Get the Redis client from the Socket.IO manager."""
|
||||
from openhands.server.shared import sio
|
||||
|
||||
return getattr(sio.manager, 'redis', None)
|
||||
|
||||
async def _acquire_user_creation_lock(self) -> bool:
|
||||
"""Attempt to acquire a distributed lock for user creation.
|
||||
|
||||
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
|
||||
Returns False if another process holds the lock.
|
||||
"""
|
||||
redis_client = self._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return True # Proceed without locking if Redis is unavailable
|
||||
|
||||
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{self.user_id}'
|
||||
lock_acquired = await redis_client.set(
|
||||
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
|
||||
)
|
||||
return bool(lock_acquired)
|
||||
|
||||
async def create_default_settings(self, user_settings: UserSettings | None):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not self.user_id:
|
||||
return None
|
||||
|
||||
# Prevent duplicate settings creation using distributed lock
|
||||
if not await self._acquire_user_creation_lock():
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
return await self.load()
|
||||
|
||||
# Only users that have specified a payment method get default settings
|
||||
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
|
||||
self.user_id
|
||||
):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:no_payment',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
settings: Settings | None = None
|
||||
if user_settings is None:
|
||||
settings = Settings(
|
||||
language='en',
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
elif isinstance(user_settings, UserSettings):
|
||||
# Convert UserSettings (SQLAlchemy model) to Settings (Pydantic model)
|
||||
kwargs = {
|
||||
c.name: getattr(user_settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
if settings:
|
||||
settings = await self.update_settings_with_litellm_default(settings)
|
||||
if settings is None:
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:litellm_update_failed',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
await self.store(settings)
|
||||
return settings
|
||||
|
||||
async def load_legacy_file_store_settings(self, github_user_id: str):
|
||||
if not github_user_id:
|
||||
return None
|
||||
|
||||
file_store = get_file_store(self.config.file_store, self.config.file_store_path)
|
||||
path = f'users/github/{github_user_id}/settings.json'
|
||||
|
||||
try:
|
||||
json_str = await call_sync_from_async(file_store.read, path)
|
||||
logger.info(
|
||||
'saas_settings_store:load_legacy_file_store_settings:found',
|
||||
extra={'github_user_id': github_user_id},
|
||||
)
|
||||
kwargs = json.loads(json_str)
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'saas_settings_store:load_legacy_file_store_settings:error',
|
||||
extra={'github_user_id': github_user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
def _has_custom_settings(
|
||||
self, settings: Settings, old_user_version: int | None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has custom LLM settings that should be preserved.
|
||||
Returns True if user customized either model or base_url.
|
||||
|
||||
Args:
|
||||
settings: The user's current settings
|
||||
old_user_version: The user's old settings version, if any
|
||||
|
||||
Returns:
|
||||
True if user has custom settings, False if using old defaults
|
||||
"""
|
||||
# Normalize values
|
||||
user_model = (
|
||||
settings.llm_model.strip()
|
||||
if settings.llm_model and settings.llm_model.strip()
|
||||
else None
|
||||
)
|
||||
user_base_url = (
|
||||
settings.llm_base_url.strip()
|
||||
if settings.llm_base_url and settings.llm_base_url.strip()
|
||||
else None
|
||||
)
|
||||
|
||||
# Custom base_url = definitely custom settings (BYOK)
|
||||
if user_base_url and user_base_url != LITE_LLM_API_URL:
|
||||
return True
|
||||
|
||||
# No model set = using defaults
|
||||
if not user_model:
|
||||
return False
|
||||
|
||||
# Check if model matches old version's default
|
||||
if (
|
||||
old_user_version
|
||||
and old_user_version < CURRENT_USER_SETTINGS_VERSION
|
||||
and old_user_version in USER_SETTINGS_VERSION_TO_MODEL
|
||||
):
|
||||
old_default_base = USER_SETTINGS_VERSION_TO_MODEL[old_user_version]
|
||||
user_model_base = user_model.split('/')[-1]
|
||||
if user_model_base == old_default_base:
|
||||
return False # Matches old default
|
||||
|
||||
return True # Custom model
|
||||
|
||||
async def update_settings_with_litellm_default(
|
||||
self, settings: Settings
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
|
||||
# Check if user has custom settings
|
||||
has_custom = self._has_custom_settings(settings, settings.user_version)
|
||||
|
||||
# Determine model to use (needed before LiteLLM user creation)
|
||||
llm_model_to_use = (
|
||||
settings.llm_model
|
||||
if has_custom and settings.llm_model
|
||||
else get_default_litellm_model()
|
||||
)
|
||||
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(self.user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Get the previous max budget to prevent accidental loss.
|
||||
#
|
||||
# LiteLLM v1.80+ returns 404 for non-existent users (previously returned empty user_info)
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={self.user_id}'
|
||||
)
|
||||
user_info: dict
|
||||
if response.status_code == 404:
|
||||
# New user - doesn't exist in LiteLLM yet (v1.80+ behavior)
|
||||
user_info = {}
|
||||
else:
|
||||
# For any other status, use standard error handling
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
user_info = response_json.get('user_info') or {}
|
||||
logger.info(
|
||||
f'creating_litellm_user: {self.user_id}; prev_max_budget: {user_info.get("max_budget")}; prev_metadata: {user_info.get("metadata")}'
|
||||
)
|
||||
max_budget = user_info.get('max_budget') or DEFAULT_INITIAL_BUDGET
|
||||
spend = user_info.get('spend') or 0
|
||||
|
||||
if not user:
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = self._get_user_settings_by_keycloak_id(
|
||||
user_settings = self.get_user_settings_by_keycloak_id(
|
||||
self.user_id, session
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(self.user_id, user_settings)
|
||||
else:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
|
||||
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.user_version < 4
|
||||
and user_settings.billing_margin
|
||||
and user_settings.billing_margin != 1.0
|
||||
):
|
||||
billing_margin = user_settings.billing_margin
|
||||
logger.info(
|
||||
'user_settings_v4_budget_upgrade',
|
||||
extra={
|
||||
'max_budget': max_budget,
|
||||
'billing_margin': billing_margin,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
max_budget *= billing_margin
|
||||
spend *= billing_margin
|
||||
user_settings.billing_margin = 1.0
|
||||
session.commit()
|
||||
|
||||
email = keycloak_user_info.get('email')
|
||||
|
||||
# We explicitly delete here to guard against odd inherited settings on upgrade.
|
||||
# We don't care if this fails with a 404
|
||||
await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [self.user_id]}
|
||||
)
|
||||
|
||||
# Create the new litellm user
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, email, max_budget, spend, llm_model_to_use
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={'user_id': self.user_id, 'email': email},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, None, max_budget, spend, llm_model_to_use
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [self.user_id],
|
||||
'email': email,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
org_id = user.current_org_id
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item, str(org_id))
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:user_created',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
if has_custom:
|
||||
settings.llm_model = settings.llm_model or get_default_litellm_model()
|
||||
settings.llm_base_url = settings.llm_base_url or LITE_LLM_API_URL
|
||||
settings.llm_api_key = settings.llm_api_key or SecretStr(key)
|
||||
else:
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
|
||||
session.commit()
|
||||
settings.agent = 'CodeActAgent'
|
||||
|
||||
return settings
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
@@ -177,9 +443,6 @@ class SaasSettingsStore(SettingsStore):
|
||||
logger.debug(f'saas_settings_store.get_instance::{user_id}')
|
||||
return SaasSettingsStore(user_id, session_maker, config)
|
||||
|
||||
def _should_encrypt(self, key):
|
||||
return key in self.ENCRYPT_VALUES
|
||||
|
||||
def _decrypt_kwargs(self, kwargs: dict):
|
||||
fernet = self._fernet()
|
||||
for key, value in kwargs.items():
|
||||
@@ -223,24 +486,21 @@ class SaasSettingsStore(SettingsStore):
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
return Fernet(fernet_key)
|
||||
|
||||
def _should_encrypt(self, key: str) -> bool:
|
||||
return key in ('llm_api_key', 'llm_api_key_for_byor', 'search_api_key')
|
||||
|
||||
def _is_openhands_provider(self, item: Settings) -> bool:
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
|
||||
async def _ensure_openhands_api_key(self, item: Settings) -> None:
|
||||
"""Generate and set the OpenHands API key for the given settings.
|
||||
|
||||
First checks if an existing key with the OpenHands alias exists,
|
||||
and reuses it if found. Otherwise, generates a new key.
|
||||
"""
|
||||
# Generate new key if none exists
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
None,
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
|
||||
generated_key = await self._generate_openhands_key()
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
@@ -252,3 +512,83 @@ class SaasSettingsStore(SettingsStore):
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
async def _create_user_in_lite_llm(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
max_budget: int,
|
||||
spend: int,
|
||||
llm_model: str,
|
||||
):
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
'user_id': str(self.user_id),
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': True,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': CURRENT_USER_SETTINGS_VERSION,
|
||||
'model': llm_model,
|
||||
},
|
||||
'key_alias': f'OpenHands Cloud - user {self.user_id}',
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
async def _generate_openhands_key(self) -> str | None:
|
||||
"""Generate a new OpenHands provider key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'saas_settings_store:_generate_openhands_key:litellm_config_not_found',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
'user_id': self.user_id,
|
||||
'metadata': {'type': 'openhands'},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'saas_settings_store:_generate_openhands_key:success',
|
||||
extra={
|
||||
'user_id': self.user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': (
|
||||
key[:10] + '...' if key and len(key) > 10 else key
|
||||
),
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'saas_settings_store:_generate_openhands_key:no_key_in_response',
|
||||
extra={'user_id': self.user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'saas_settings_store:_generate_openhands_key:error',
|
||||
extra={'user_id': self.user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Boolean, Column, Identity, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -10,9 +8,5 @@ class SlackConversation(Base): # type: ignore
|
||||
conversation_id = Column(String, nullable=False, index=True)
|
||||
channel_id = Column(String, nullable=False)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
parent_id = Column(String, nullable=True, index=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_conversations')
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Identity, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, DateTime, Identity, Integer, String, text
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -8,7 +6,6 @@ class SlackUser(Base): # type: ignore
|
||||
__tablename__ = 'slack_users'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
slack_user_id = Column(String, nullable=False, index=True)
|
||||
slack_display_name = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -16,6 +13,3 @@ class SlackUser(Base): # type: ignore
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_users')
|
||||
|
||||
@@ -4,4 +4,5 @@ from openhands.app_server.app_conversation.sql_app_conversation_info_service imp
|
||||
|
||||
StoredConversationMetadata = _StoredConversationMetadata
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadata']
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for ConversationMetadataSaas.
|
||||
|
||||
This model stores the SaaS-specific metadata for conversations,
|
||||
containing only the conversation_id, user_id, and org_id.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID as SQL_UUID
|
||||
from sqlalchemy import Column, ForeignKey, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class StoredConversationMetadataSaas(Base): # type: ignore
|
||||
"""SaaS conversation metadata model containing user and org associations."""
|
||||
|
||||
__tablename__ = 'conversation_metadata_saas'
|
||||
|
||||
conversation_id = Column(String, primary_key=True)
|
||||
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = relationship('User', back_populates='stored_conversation_metadata_saas')
|
||||
org = relationship('Org', back_populates='stored_conversation_metadata_saas')
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadataSaas']
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -8,10 +6,6 @@ class StoredCustomSecrets(Base): # type: ignore
|
||||
__tablename__ = 'custom_secrets'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=True, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
secret_name = Column(String, nullable=False)
|
||||
secret_value = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='user_secrets')
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -15,7 +13,6 @@ class StripeCustomer(Base): # type: ignore
|
||||
__tablename__ = 'stripe_customers'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
stripe_customer_id = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
@@ -26,6 +23,3 @@ class StripeCustomer(Base): # type: ignore
|
||||
onupdate=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='stripe_customers')
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for User.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import (
|
||||
UUID,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class User(Base): # type: ignore
|
||||
"""User model with organizational relationships."""
|
||||
|
||||
__tablename__ = 'user'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
current_org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=True)
|
||||
accepted_tos = Column(DateTime, nullable=True)
|
||||
enable_sound_notifications = Column(Boolean, nullable=True)
|
||||
language = Column(String, nullable=True)
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
email = Column(String, nullable=True)
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
org_members = relationship('OrgMember', back_populates='user')
|
||||
current_org = relationship('Org', back_populates='current_users')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='user'
|
||||
)
|
||||
@@ -39,6 +39,3 @@ class UserSettings(Base): # type: ignore
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
already_migrated = Column(
|
||||
Boolean, nullable=True, default=False
|
||||
) # False = not migrated, True = migrated
|
||||
|
||||
@@ -1,909 +0,0 @@
|
||||
"""
|
||||
Store class for managing users.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.encrypt_utils import (
|
||||
decrypt_legacy_model,
|
||||
encrypt_legacy_value,
|
||||
)
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
# The max possible time to wait for another process to finish creating a user before retrying
|
||||
_REDIS_CREATE_TIMEOUT_SECONDS = 30
|
||||
# The delay to wait for another process to finish creating a user before trying to load again
|
||||
_RETRY_LOAD_DELAY_SECONDS = 2
|
||||
# Redis key prefix for user creation locks
|
||||
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
|
||||
|
||||
|
||||
class UserStore:
|
||||
"""Store for managing users."""
|
||||
|
||||
@staticmethod
|
||||
async def create_user(
|
||||
user_id: str,
|
||||
user_info: dict,
|
||||
role_id: Optional[int] = None,
|
||||
) -> User | None:
|
||||
"""Create a new user."""
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_email=user_info['email'],
|
||||
v1_enabled=True,
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
settings = await UserStore.create_default_settings(
|
||||
org_id=str(org.id), user_id=user_id
|
||||
)
|
||||
|
||||
if not settings:
|
||||
return None
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
|
||||
# avoid setting org member llm fields to use org defaults on user creation
|
||||
del org_member_kwargs['llm_model']
|
||||
del org_member_kwargs['llm_base_url']
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def _get_redis_client():
|
||||
"""Get the Redis client from the Socket.IO manager."""
|
||||
from openhands.server.shared import sio
|
||||
|
||||
return getattr(sio.manager, 'redis', None)
|
||||
|
||||
@staticmethod
|
||||
async def _acquire_user_creation_lock(user_id: str) -> bool:
|
||||
"""Attempt to acquire a distributed lock for user creation.
|
||||
|
||||
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
|
||||
Returns False if another process holds the lock.
|
||||
"""
|
||||
redis_client = UserStore._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'user_store:_acquire_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True # Proceed without locking if Redis is unavailable
|
||||
|
||||
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
|
||||
lock_acquired = await redis_client.set(
|
||||
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
|
||||
)
|
||||
return bool(lock_acquired)
|
||||
|
||||
@staticmethod
|
||||
async def migrate_user(
|
||||
user_id: str,
|
||||
user_settings: UserSettings,
|
||||
user_info: dict,
|
||||
) -> User:
|
||||
if not user_id or not user_settings:
|
||||
return None
|
||||
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
org_version=user_settings.user_version,
|
||||
contact_name=user_info['username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_litellm_migrate_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await LiteLlmManager.migrate_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_litellm_migrate_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
custom_settings = UserStore._has_custom_settings(
|
||||
decrypted_user_settings, user_settings.user_version
|
||||
)
|
||||
|
||||
# avoids circular reference. This migrate method is temprorary until all users are migrated.
|
||||
from integrations.stripe_service import migrate_customer
|
||||
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await migrate_customer(session, user_id, org)
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_user_settings(decrypted_user_settings)
|
||||
org_kwargs.pop('id', None)
|
||||
|
||||
# if user has custom settings, set org defaults to current version
|
||||
if custom_settings:
|
||||
org_kwargs['default_llm_model'] = get_default_litellm_model()
|
||||
org_kwargs['llm_base_url'] = LITE_LLM_API_URL
|
||||
org_kwargs['org_version'] = ORG_SETTINGS_VERSION
|
||||
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
user_kwargs.pop('id', None)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=None,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_get_role_by_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
role = await RoleStore.get_role_by_name_async('owner')
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_get_role_by_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
|
||||
# if the user did not have custom settings in the old model,
|
||||
# then use the org defaults by not setting org_member fields
|
||||
if not custom_settings:
|
||||
del org_member_kwargs['llm_model']
|
||||
del org_member_kwargs['llm_base_url']
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
|
||||
# Mark the old user_settings as migrated instead of deleting
|
||||
user_settings.already_migrated = True
|
||||
session.merge(user_settings)
|
||||
session.flush()
|
||||
logger.debug(
|
||||
'user_store:migrate_user:session_flush_complete',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# need to migrate conversation metadata
|
||||
session.execute(
|
||||
text("""
|
||||
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||
SELECT
|
||||
conversation_id,
|
||||
:user_id,
|
||||
:user_id
|
||||
FROM conversation_metadata
|
||||
WHERE user_id = :user_id
|
||||
"""),
|
||||
{'user_id': user_id},
|
||||
)
|
||||
|
||||
# Update org_id for tables that had org_id added
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Update stripe_customers
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_users
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_conversation
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update api_keys
|
||||
session.execute(
|
||||
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update custom_secrets
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update billing_sessions
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
logger.debug(
|
||||
'user_store:migrate_user:session_committed',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def downgrade_user(user_id: str) -> UserSettings | None:
|
||||
"""Downgrade a migrated user back to the pre-migration state.
|
||||
|
||||
This reverses the migrate_user operation:
|
||||
1. Get the user's settings from user_settings table (migrated users) or
|
||||
create new user_settings from org_members table (new sign-ups)
|
||||
2. Call LiteLlmManager.downgrade_entries to revert LiteLLM state
|
||||
3. Copy user_id from conversation_metadata_saas to conversation_metadata
|
||||
4. Delete conversation_metadata_saas entries
|
||||
5. Reset org_id columns in related tables (stripe_customers, slack_users, etc.)
|
||||
6. Delete the org_member and org entries
|
||||
7. Delete the user entry
|
||||
8. Set already_migrated=False on user_settings
|
||||
|
||||
For new sign-ups (users who registered after migration was deployed),
|
||||
there won't be an existing user_settings entry. In this case, we fall back
|
||||
to the org_members table to get the user's API keys and settings, and create
|
||||
a new user_settings entry for them.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to downgrade
|
||||
|
||||
Returns:
|
||||
The user_settings if downgrade was successful, None otherwise.
|
||||
Returns None if the org has multiple members (not a personal org).
|
||||
"""
|
||||
logger.info(
|
||||
'user_store:downgrade_user:start',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
# Get the user and their org_member
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:user_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the user's personal org (org_id == user_id)
|
||||
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
'user_store:downgrade_user:org_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the user_settings (for migrated users)
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# For new sign-ups after migration, user_settings won't exist
|
||||
# Fall back to getting data from org_members
|
||||
is_new_signup = False
|
||||
if not user_settings:
|
||||
logger.info(
|
||||
'user_store:downgrade_user:user_settings_not_found_checking_org_members',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Get org_members for this org - should only be one for personal orgs
|
||||
org_members = (
|
||||
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
|
||||
)
|
||||
|
||||
if len(org_members) != 1:
|
||||
logger.error(
|
||||
'user_store:downgrade_user:unexpected_org_members_count',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'org_members_count': len(org_members),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
org_member = org_members[0]
|
||||
is_new_signup = True
|
||||
|
||||
# Create a new user_settings entry from OrgMember, User, and Org data
|
||||
# This is needed for new sign-ups who don't have user_settings
|
||||
user_settings = UserStore._create_user_settings_from_entities(
|
||||
user_id, org_member, user, org
|
||||
)
|
||||
session.add(user_settings)
|
||||
session.flush()
|
||||
logger.info(
|
||||
'user_store:downgrade_user:created_user_settings_from_org_member',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Call LiteLLM downgrade
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
logger.debug(
|
||||
'user_store:downgrade_user:calling_litellm_downgrade_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Get the API keys for LiteLLM downgrade
|
||||
if is_new_signup:
|
||||
# For new signups, we already have decrypted values in user_settings
|
||||
decrypted_user_settings = user_settings
|
||||
else:
|
||||
# For migrated users, decrypt the legacy model
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
|
||||
await LiteLlmManager.downgrade_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
logger.debug(
|
||||
'user_store:downgrade_user:done_litellm_downgrade_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Step 3: Copy user_id from conversation_metadata_saas to conversation_metadata
|
||||
# This ensures any conversations created after migration have their user_id
|
||||
# preserved in the original table before we delete the saas entries
|
||||
session.execute(
|
||||
text("""
|
||||
UPDATE conversation_metadata
|
||||
SET user_id = :user_id
|
||||
WHERE conversation_id IN (
|
||||
SELECT conversation_id
|
||||
FROM conversation_metadata_saas
|
||||
WHERE user_id = :user_uuid
|
||||
)
|
||||
"""),
|
||||
{'user_id': user_id, 'user_uuid': user_uuid},
|
||||
)
|
||||
|
||||
# Step 4: Delete conversation_metadata_saas entries
|
||||
session.execute(
|
||||
text('DELETE FROM conversation_metadata_saas WHERE user_id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 5: Reset org_id columns in related tables
|
||||
# Reset stripe_customers
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset slack_users
|
||||
session.execute(
|
||||
text('UPDATE slack_users SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset slack_conversation
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset api_keys
|
||||
session.execute(
|
||||
text('UPDATE api_keys SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset custom_secrets
|
||||
session.execute(
|
||||
text('UPDATE custom_secrets SET org_id = NULL WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Reset billing_sessions
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = NULL WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 6: Delete org_member entries for this org
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 7: Delete the user entry
|
||||
session.execute(
|
||||
text('DELETE FROM "user" WHERE id = :user_id'),
|
||||
{'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Delete the org entry
|
||||
session.execute(
|
||||
text('DELETE FROM org WHERE id = :org_id'),
|
||||
{'org_id': user_uuid},
|
||||
)
|
||||
|
||||
# Step 8: Set already_migrated=False on user_settings and encrypt fields
|
||||
user_settings.already_migrated = False
|
||||
|
||||
# Re-encrypt the sensitive fields before storing in the DB
|
||||
encrypt_keys = [
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
]
|
||||
for key in encrypt_keys:
|
||||
value = getattr(user_settings, key, None)
|
||||
if value is not None:
|
||||
setattr(user_settings, key, encrypt_legacy_value(value))
|
||||
|
||||
session.merge(user_settings)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_store:downgrade_user:complete',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (sync version).
|
||||
|
||||
Note: This method uses call_async_from_sync internally which creates a new
|
||||
event loop. If you're already in an async context, use get_user_by_id_async
|
||||
instead to avoid event loop conflicts.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if we need to migrate from user_settings
|
||||
while not call_async_from_sync(
|
||||
UserStore._acquire_user_creation_lock, GENERAL_TIMEOUT, user_id
|
||||
):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'user_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
call_async_from_sync(
|
||||
asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS
|
||||
)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = call_async_from_sync(
|
||||
token_manager.get_user_info_from_user_id,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
)
|
||||
user = call_async_from_sync(
|
||||
UserStore.migrate_user,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_id_async(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (async version).
|
||||
|
||||
This is the preferred method when calling from an async context as it
|
||||
avoids event loop conflicts that can occur with the sync version.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if we need to migrate from user_settings
|
||||
while not await UserStore._acquire_user_creation_lock(user_id):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:start_migration',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
)
|
||||
user_settings = result.scalars().first()
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id)
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:calling_migrate_user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
user = await UserStore.migrate_user(
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
# Prevent circular imports
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
@staticmethod
|
||||
async def create_default_settings(
|
||||
org_id: str, user_id: str, create_user: bool = True
|
||||
) -> Optional['Settings']:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:start',
|
||||
extra={'org_id': org_id, 'user_id': user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
settings = Settings(language='en', enable_proactive_conversation_starters=True)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
settings = await LiteLlmManager.create_entries(
|
||||
org_id, user_id, settings, create_user
|
||||
)
|
||||
if not settings:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:litellm_create_failed',
|
||||
extra={'org_id': org_id},
|
||||
)
|
||||
return None
|
||||
|
||||
return settings
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: 'Settings'):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def _create_user_settings_from_entities(
|
||||
user_id: str, org_member: OrgMember, user: User, org: Org
|
||||
) -> UserSettings:
|
||||
"""Create UserSettings from OrgMember, User, and Org data.
|
||||
|
||||
Uses OrgMember values first. If an OrgMember field is None and there's
|
||||
a corresponding "default_" field in Org, use the Org value.
|
||||
Also pulls relevant fields from User.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID
|
||||
org_member: The OrgMember entity
|
||||
user: The User entity
|
||||
org: The Org entity
|
||||
|
||||
Returns:
|
||||
A new UserSettings object populated from the entities
|
||||
"""
|
||||
# Mapping from OrgMember fields to corresponding Org "default_" fields
|
||||
org_member_to_org_default = {
|
||||
'llm_model': 'default_llm_model',
|
||||
'llm_base_url': 'default_llm_base_url',
|
||||
'max_iterations': 'default_max_iterations',
|
||||
}
|
||||
|
||||
def get_value_with_org_fallback(field_name: str, org_member_value):
|
||||
"""Get value from OrgMember, falling back to Org default if None."""
|
||||
if org_member_value is not None:
|
||||
return org_member_value
|
||||
org_default_field = org_member_to_org_default.get(field_name)
|
||||
if org_default_field and hasattr(org, org_default_field):
|
||||
return getattr(org, org_default_field)
|
||||
return None
|
||||
|
||||
# Get values from OrgMember with Org fallback for fields with default_ prefix
|
||||
llm_model = get_value_with_org_fallback('llm_model', org_member.llm_model)
|
||||
llm_base_url = get_value_with_org_fallback(
|
||||
'llm_base_url', org_member.llm_base_url
|
||||
)
|
||||
max_iterations = get_value_with_org_fallback(
|
||||
'max_iterations', org_member.max_iterations
|
||||
)
|
||||
|
||||
return UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
# OrgMember fields
|
||||
llm_api_key=org_member.llm_api_key.get_secret_value()
|
||||
if org_member.llm_api_key
|
||||
else None,
|
||||
llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value()
|
||||
if org_member.llm_api_key_for_byor
|
||||
else None,
|
||||
llm_model=llm_model,
|
||||
llm_base_url=llm_base_url,
|
||||
max_iterations=max_iterations,
|
||||
# User fields
|
||||
accepted_tos=user.accepted_tos,
|
||||
enable_sound_notifications=user.enable_sound_notifications,
|
||||
language=user.language,
|
||||
user_consents_to_analytics=user.user_consents_to_analytics,
|
||||
email=user.email,
|
||||
email_verified=user.email_verified,
|
||||
git_user_name=user.git_user_name,
|
||||
git_user_email=user.git_user_email,
|
||||
# Org fields
|
||||
agent=org.agent,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
user_version=org.org_version,
|
||||
mcp_config=org.mcp_config,
|
||||
search_api_key=org.search_api_key.get_secret_value()
|
||||
if org.search_api_key
|
||||
else None,
|
||||
sandbox_api_key=org.sandbox_api_key.get_secret_value()
|
||||
if org.sandbox_api_key
|
||||
else None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
condenser_max_size=org.condenser_max_size,
|
||||
already_migrated=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _has_custom_settings(
|
||||
user_settings: UserSettings, old_user_version: int | None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has custom LLM settings that should be preserved.
|
||||
Returns True if user customized either model or base_url.
|
||||
|
||||
Args:
|
||||
settings: The user's current settings
|
||||
old_user_version: The user's old settings version, if any
|
||||
|
||||
Returns:
|
||||
True if user has custom settings, False if using old defaults
|
||||
"""
|
||||
# Normalize values
|
||||
user_model = (
|
||||
user_settings.llm_model.strip() or None if user_settings.llm_model else None
|
||||
)
|
||||
user_base_url = (
|
||||
user_settings.llm_base_url.strip() or None
|
||||
if user_settings.llm_base_url
|
||||
else None
|
||||
)
|
||||
|
||||
# Custom base_url = definitely custom settings (BYOK)
|
||||
if user_base_url and user_base_url != LITE_LLM_API_URL:
|
||||
return True
|
||||
|
||||
# No model set = using defaults
|
||||
if not user_model:
|
||||
return False
|
||||
|
||||
# Check if model matches old version's default
|
||||
if (
|
||||
old_user_version
|
||||
and old_user_version <= ORG_SETTINGS_VERSION
|
||||
and old_user_version in PERSONAL_WORKSPACE_VERSION_TO_MODEL
|
||||
):
|
||||
old_default_base = PERSONAL_WORKSPACE_VERSION_TO_MODEL[old_user_version]
|
||||
user_model_base = user_model.split('/')[-1]
|
||||
if user_model_base == old_default_base:
|
||||
return False # Matches old default
|
||||
|
||||
return True # Custom model
|
||||
@@ -21,7 +21,7 @@ from sqlalchemy import text
|
||||
# Add the parent directory to the path so we can import from storage
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from server.auth.token_manager import get_keycloak_admin
|
||||
from storage.database import get_engine
|
||||
from storage.database import engine
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -85,7 +85,7 @@ def get_recent_conversations(minutes: int = 60) -> List[Dict[str, Any]]:
|
||||
created_at DESC
|
||||
""")
|
||||
|
||||
with get_engine().connect() as connection:
|
||||
with engine.connect() as connection:
|
||||
result = connection.execute(query, {'minutes': minutes})
|
||||
conversations = [
|
||||
{
|
||||
|
||||
@@ -13,7 +13,7 @@ from sqlalchemy import text
|
||||
# Add the parent directory to the path so we can import from storage
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from storage.database import get_engine
|
||||
from storage.database import engine
|
||||
|
||||
|
||||
def test_conversation_count_query():
|
||||
@@ -29,8 +29,6 @@ def test_conversation_count_query():
|
||||
user_id
|
||||
""")
|
||||
|
||||
engine = get_engine()
|
||||
|
||||
with engine.connect() as connection:
|
||||
count_result = connection.execute(count_query)
|
||||
user_counts = [
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import (
|
||||
UserVersionUpgradeProcessor,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
@@ -14,16 +15,11 @@ from storage.conversation_work import ConversationWork
|
||||
from storage.device_code import DeviceCode # noqa: F401
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -72,6 +68,7 @@ def add_minimal_fixtures(session_maker):
|
||||
session.add(
|
||||
StoredConversationMetadata(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id='mock-user-id',
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
last_updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
accumulated_cost=5.25,
|
||||
@@ -80,13 +77,6 @@ def add_minimal_fixtures(session_maker):
|
||||
total_tokens=750,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredConversationMetadataSaas(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredOfflineToken(
|
||||
user_id='mock-user-id',
|
||||
@@ -95,38 +85,7 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
Org(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
name='mock-org',
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
Role(
|
||||
id=1,
|
||||
name='admin',
|
||||
rank=1,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
User(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
current_org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_consents_to_analytics=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
OrgMember(
|
||||
org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
role_id=1,
|
||||
llm_api_key='mock-api-key',
|
||||
status='active',
|
||||
)
|
||||
)
|
||||
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id='mock-user-id',
|
||||
@@ -135,6 +94,13 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-10'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
UserSettings(
|
||||
keycloak_user_id='mock-user-id',
|
||||
user_consents_to_analytics=True,
|
||||
user_version=CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
ConversationWork(
|
||||
conversation_id='mock-conversation-id',
|
||||
@@ -143,6 +109,17 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
maintenance_task = MaintenanceTask(
|
||||
status=MaintenanceTaskStatus.PENDING,
|
||||
)
|
||||
maintenance_task.set_processor(
|
||||
UserVersionUpgradeProcessor(
|
||||
user_ids=['mock-user-id'],
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
session.add(maintenance_task)
|
||||
session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -6,13 +6,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.jira.jira_manager import JiraManager
|
||||
from integrations.jira.jira_payload import (
|
||||
JiraEventType,
|
||||
JiraWebhookPayload,
|
||||
)
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import DictLoader, Environment
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_user import JiraUser
|
||||
@@ -27,7 +25,7 @@ def mock_token_manager():
|
||||
"""Create a mock TokenManager for testing."""
|
||||
token_manager = MagicMock()
|
||||
token_manager.get_user_id_from_user_email = AsyncMock()
|
||||
token_manager.decrypt_text = MagicMock(return_value='decrypted_key')
|
||||
token_manager.decrypt_text = MagicMock()
|
||||
return token_manager
|
||||
|
||||
|
||||
@@ -62,7 +60,6 @@ def sample_jira_workspace():
|
||||
workspace = MagicMock(spec=JiraWorkspace)
|
||||
workspace.id = 1
|
||||
workspace.name = 'test.atlassian.net'
|
||||
workspace.jira_cloud_id = 'cloud-123'
|
||||
workspace.admin_user_id = 'admin_id'
|
||||
workspace.webhook_secret = 'encrypted_secret'
|
||||
workspace.svc_acc_email = 'service@example.com'
|
||||
@@ -78,41 +75,22 @@ def sample_user_auth():
|
||||
user_auth.get_provider_tokens = AsyncMock(return_value={})
|
||||
user_auth.get_access_token = AsyncMock(return_value='test_token')
|
||||
user_auth.get_user_id = AsyncMock(return_value='test_user_id')
|
||||
user_auth.get_secrets = AsyncMock(return_value=None)
|
||||
return user_auth
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_webhook_payload():
|
||||
"""Create a sample JiraWebhookPayload for testing."""
|
||||
return JiraWebhookPayload(
|
||||
event_type=JiraEventType.COMMENT_MENTION,
|
||||
raw_event='comment_created',
|
||||
def sample_job_context():
|
||||
"""Create a sample JobContext for testing."""
|
||||
return JobContext(
|
||||
issue_id='12345',
|
||||
issue_key='TEST-123',
|
||||
user_msg='Fix this bug @openhands',
|
||||
user_email='user@test.com',
|
||||
display_name='Test User',
|
||||
account_id='user123',
|
||||
workspace_name='test.atlassian.net',
|
||||
base_api_url='https://test.atlassian.net',
|
||||
comment_body='Fix this bug @openhands',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_label_webhook_payload():
|
||||
"""Create a sample labeled ticket JiraWebhookPayload for testing."""
|
||||
return JiraWebhookPayload(
|
||||
event_type=JiraEventType.LABELED_TICKET,
|
||||
raw_event='jira:issue_updated',
|
||||
issue_id='12345',
|
||||
issue_key='PROJ-123',
|
||||
user_email='user@company.com',
|
||||
display_name='Test User',
|
||||
account_id='user456',
|
||||
workspace_name='jira.company.com',
|
||||
base_api_url='https://jira.company.com',
|
||||
comment_body='',
|
||||
issue_title='Test Issue',
|
||||
issue_description='This is a test issue',
|
||||
)
|
||||
|
||||
|
||||
@@ -203,17 +181,31 @@ def jira_conversation():
|
||||
|
||||
@pytest.fixture
|
||||
def new_conversation_view(
|
||||
sample_webhook_payload, sample_user_auth, sample_jira_user, sample_jira_workspace
|
||||
sample_job_context, sample_user_auth, sample_jira_user, sample_jira_workspace
|
||||
):
|
||||
"""JiraNewConversationView instance for testing"""
|
||||
return JiraNewConversationView(
|
||||
payload=sample_webhook_payload,
|
||||
job_context=sample_job_context,
|
||||
saas_user_auth=sample_user_auth,
|
||||
jira_user=sample_jira_user,
|
||||
jira_workspace=sample_jira_workspace,
|
||||
selected_repo='test/repo1',
|
||||
conversation_id='conv-123',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def existing_conversation_view(
|
||||
sample_job_context, sample_user_auth, sample_jira_user, sample_jira_workspace
|
||||
):
|
||||
"""JiraExistingConversationView instance for testing"""
|
||||
return JiraExistingConversationView(
|
||||
job_context=sample_job_context,
|
||||
saas_user_auth=sample_user_auth,
|
||||
jira_user=sample_jira_user,
|
||||
jira_workspace=sample_jira_workspace,
|
||||
selected_repo='test/repo1',
|
||||
conversation_id='conv-123',
|
||||
_decrypted_api_key='decrypted_key',
|
||||
)
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,87 +5,28 @@ Tests for Jira view classes and factory.
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.jira.jira_payload import (
|
||||
JiraEventType,
|
||||
JiraPayloadError,
|
||||
JiraPayloadParser,
|
||||
JiraPayloadSkipped,
|
||||
JiraPayloadSuccess,
|
||||
)
|
||||
from integrations.jira.jira_types import RepositoryNotFoundError, StartingConvoException
|
||||
from integrations.jira.jira_types import StartingConvoException
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraFactory,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
|
||||
from openhands.core.schema.agent import AgentState
|
||||
|
||||
|
||||
class TestJiraNewConversationView:
|
||||
"""Tests for JiraNewConversationView"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue_details_success(
|
||||
self, new_conversation_view, sample_jira_workspace
|
||||
):
|
||||
"""Test successful issue details retrieval."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
title, description = await new_conversation_view.get_issue_details()
|
||||
|
||||
assert title == 'Test Issue'
|
||||
assert description == 'Test description'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue_details_cached(self, new_conversation_view):
|
||||
"""Test issue details are cached after first call."""
|
||||
new_conversation_view._issue_title = 'Cached Title'
|
||||
new_conversation_view._issue_description = 'Cached Description'
|
||||
|
||||
title, description = await new_conversation_view.get_issue_details()
|
||||
|
||||
assert title == 'Cached Title'
|
||||
assert description == 'Cached Description'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue_details_no_title(self, new_conversation_view):
|
||||
"""Test issue details with no title raises error."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': '', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(StartingConvoException, match='does not have a title'):
|
||||
await new_conversation_view.get_issue_details()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method fetches issue details."""
|
||||
new_conversation_view._issue_title = 'Test Issue'
|
||||
new_conversation_view._issue_description = 'This is a test issue'
|
||||
|
||||
instructions, user_msg = await new_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
|
||||
|
||||
assert instructions == 'Test Jira instructions template'
|
||||
assert 'TEST-123' in user_msg
|
||||
assert 'Test Issue' in user_msg
|
||||
assert 'Fix this bug @openhands' in user_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.create_new_conversation')
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_or_update_conversation_success(
|
||||
@@ -97,8 +38,6 @@ class TestJiraNewConversationView:
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test successful conversation creation"""
|
||||
new_conversation_view._issue_title = 'Test Issue'
|
||||
new_conversation_view._issue_description = 'Test description'
|
||||
mock_create_conversation.return_value = mock_agent_loop_info
|
||||
mock_store.create_conversation = AsyncMock()
|
||||
|
||||
@@ -110,7 +49,6 @@ class TestJiraNewConversationView:
|
||||
mock_create_conversation.assert_called_once()
|
||||
mock_store.create_conversation.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_or_update_conversation_no_repo(
|
||||
self, new_conversation_view, mock_jinja_env
|
||||
):
|
||||
@@ -120,6 +58,18 @@ class TestJiraNewConversationView:
|
||||
with pytest.raises(StartingConvoException, match='No repository selected'):
|
||||
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
|
||||
|
||||
@patch('integrations.jira.jira_view.create_new_conversation')
|
||||
async def test_create_or_update_conversation_failure(
|
||||
self, mock_create_conversation, new_conversation_view, mock_jinja_env
|
||||
):
|
||||
"""Test conversation creation failure"""
|
||||
mock_create_conversation.side_effect = Exception('Creation failed')
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Failed to create conversation'
|
||||
):
|
||||
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
|
||||
|
||||
def test_get_response_msg(self, new_conversation_view):
|
||||
"""Test get_response_msg method"""
|
||||
response = new_conversation_view.get_response_msg()
|
||||
@@ -130,336 +80,344 @@ class TestJiraNewConversationView:
|
||||
assert 'conv-123' in response
|
||||
|
||||
|
||||
class TestJiraExistingConversationView:
|
||||
"""Tests for JiraExistingConversationView"""
|
||||
|
||||
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = existing_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert instructions == ''
|
||||
assert 'TEST-123' in user_msg
|
||||
assert 'Test Issue' in user_msg
|
||||
assert 'Fix this bug @openhands' in user_msg
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
|
||||
@patch('integrations.jira.jira_view.conversation_manager')
|
||||
@patch('integrations.jira.jira_view.get_final_agent_observation')
|
||||
async def test_create_or_update_conversation_success(
|
||||
self,
|
||||
mock_get_observation,
|
||||
mock_conversation_manager,
|
||||
mock_setup_init,
|
||||
mock_store_impl,
|
||||
existing_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_conversation_store,
|
||||
mock_conversation_init_data,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test successful existing conversation update"""
|
||||
# Setup mocks
|
||||
mock_store_impl.return_value = mock_conversation_store
|
||||
mock_setup_init.return_value = mock_conversation_init_data
|
||||
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
|
||||
return_value=mock_agent_loop_info
|
||||
)
|
||||
mock_conversation_manager.send_event_to_conversation = AsyncMock()
|
||||
|
||||
# Mock agent observation with RUNNING state
|
||||
mock_observation = MagicMock()
|
||||
mock_observation.agent_state = AgentState.RUNNING
|
||||
mock_get_observation.return_value = [mock_observation]
|
||||
|
||||
result = await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert result == 'conv-123'
|
||||
mock_conversation_manager.send_event_to_conversation.assert_called_once()
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
async def test_create_or_update_conversation_no_metadata(
|
||||
self, mock_store_impl, existing_conversation_view, mock_jinja_env
|
||||
):
|
||||
"""Test conversation update with no metadata"""
|
||||
mock_store = AsyncMock()
|
||||
mock_store.get_metadata.side_effect = FileNotFoundError(
|
||||
'No such file or directory'
|
||||
)
|
||||
mock_store_impl.return_value = mock_store
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Conversation no longer exists'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
|
||||
@patch('integrations.jira.jira_view.conversation_manager')
|
||||
@patch('integrations.jira.jira_view.get_final_agent_observation')
|
||||
async def test_create_or_update_conversation_loading_state(
|
||||
self,
|
||||
mock_get_observation,
|
||||
mock_conversation_manager,
|
||||
mock_setup_init,
|
||||
mock_store_impl,
|
||||
existing_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_conversation_store,
|
||||
mock_conversation_init_data,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test conversation update with loading state"""
|
||||
mock_store_impl.return_value = mock_conversation_store
|
||||
mock_setup_init.return_value = mock_conversation_init_data
|
||||
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
|
||||
return_value=mock_agent_loop_info
|
||||
)
|
||||
|
||||
# Mock agent observation with LOADING state
|
||||
mock_observation = MagicMock()
|
||||
mock_observation.agent_state = AgentState.LOADING
|
||||
mock_get_observation.return_value = [mock_observation]
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Conversation is still starting'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
async def test_create_or_update_conversation_failure(
|
||||
self, mock_store_impl, existing_conversation_view, mock_jinja_env
|
||||
):
|
||||
"""Test conversation update failure"""
|
||||
mock_store_impl.side_effect = Exception('Store error')
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Failed to create conversation'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
def test_get_response_msg(self, existing_conversation_view):
|
||||
"""Test get_response_msg method"""
|
||||
response = existing_conversation_view.get_response_msg()
|
||||
|
||||
assert "I'm on it!" in response
|
||||
assert 'Test User' in response
|
||||
assert 'continue tracking my progress here' in response
|
||||
assert 'conv-123' in response
|
||||
|
||||
|
||||
class TestJiraFactory:
|
||||
"""Tests for JiraFactory"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
@patch('integrations.jira.jira_view.infer_repo_from_message')
|
||||
async def test_create_view_success(
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_jira_view_from_payload_existing_conversation(
|
||||
self,
|
||||
mock_infer_repos,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
mock_store,
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
sample_repositories,
|
||||
jira_conversation,
|
||||
):
|
||||
"""Test factory creating view with repo selection."""
|
||||
# Setup mock provider handler
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.verify_repo_provider = AsyncMock(
|
||||
return_value=sample_repositories[0]
|
||||
)
|
||||
mock_create_handler.return_value = mock_handler
|
||||
|
||||
# Mock repo inference to return a repo name
|
||||
mock_infer_repos.return_value = ['test/repo1']
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
view = await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
assert isinstance(view, JiraNewConversationView)
|
||||
assert view.selected_repo == 'test/repo1'
|
||||
mock_handler.verify_repo_provider.assert_called_once_with('test/repo1')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
@patch('integrations.jira.jira_view.infer_repo_from_message')
|
||||
async def test_create_view_no_repo_in_text(
|
||||
self,
|
||||
mock_infer_repos,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
):
|
||||
"""Test factory raises error when no repo mentioned in text."""
|
||||
mock_handler = MagicMock()
|
||||
mock_create_handler.return_value = mock_handler
|
||||
|
||||
# No repos found in text
|
||||
mock_infer_repos.return_value = []
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RepositoryNotFoundError, match='Could not determine which repository'
|
||||
):
|
||||
await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
@patch('integrations.jira.jira_view.infer_repo_from_message')
|
||||
async def test_create_view_repo_verification_fails(
|
||||
self,
|
||||
mock_infer_repos,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
):
|
||||
"""Test factory raises error when repo verification fails."""
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.verify_repo_provider = AsyncMock(
|
||||
side_effect=Exception('Repository not found')
|
||||
)
|
||||
mock_create_handler.return_value = mock_handler
|
||||
|
||||
# Repos found in text but verification fails
|
||||
mock_infer_repos.return_value = ['test/repo1', 'test/repo2']
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RepositoryNotFoundError,
|
||||
match='Could not access any of the mentioned repositories',
|
||||
):
|
||||
await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
@patch('integrations.jira.jira_view.infer_repo_from_message')
|
||||
async def test_create_view_multiple_repos_verified(
|
||||
self,
|
||||
mock_infer_repos,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
sample_repositories,
|
||||
):
|
||||
"""Test factory raises error when multiple repos are verified."""
|
||||
mock_handler = MagicMock()
|
||||
# Both repos verify successfully
|
||||
mock_handler.verify_repo_provider = AsyncMock(
|
||||
side_effect=[sample_repositories[0], sample_repositories[1]]
|
||||
)
|
||||
mock_create_handler.return_value = mock_handler
|
||||
|
||||
# Multiple repos found in text
|
||||
mock_infer_repos.return_value = ['test/repo1', 'test/repo2']
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RepositoryNotFoundError, match='Multiple repositories found'
|
||||
):
|
||||
await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.JiraFactory._create_provider_handler')
|
||||
async def test_create_view_no_provider(
|
||||
self,
|
||||
mock_create_handler,
|
||||
sample_webhook_payload,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
):
|
||||
"""Test factory raises error when no provider is connected."""
|
||||
mock_create_handler.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
'fields': {'summary': 'Test Issue', 'description': 'Test description'}
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RepositoryNotFoundError, match='No Git provider connected'
|
||||
):
|
||||
await JiraFactory.create_view(
|
||||
payload=sample_webhook_payload,
|
||||
workspace=sample_jira_workspace,
|
||||
user=sample_jira_user,
|
||||
user_auth=sample_user_auth,
|
||||
decrypted_api_key='test_api_key',
|
||||
)
|
||||
|
||||
|
||||
class TestJiraPayloadParser:
|
||||
"""Tests for JiraPayloadParser"""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create a parser for testing."""
|
||||
return JiraPayloadParser(oh_label='openhands', inline_oh_label='@openhands')
|
||||
|
||||
def test_parse_label_event_success(
|
||||
self, parser, sample_issue_update_webhook_payload
|
||||
):
|
||||
"""Test parsing label event."""
|
||||
result = parser.parse(sample_issue_update_webhook_payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.LABELED_TICKET
|
||||
assert result.payload.issue_key == 'PROJ-123'
|
||||
|
||||
def test_parse_comment_event_success(self, parser, sample_comment_webhook_payload):
|
||||
"""Test parsing comment event."""
|
||||
result = parser.parse(sample_comment_webhook_payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.COMMENT_MENTION
|
||||
assert result.payload.issue_key == 'TEST-123'
|
||||
assert '@openhands' in result.payload.comment_body
|
||||
|
||||
def test_parse_unknown_event_skipped(self, parser):
|
||||
"""Test unknown event is skipped."""
|
||||
payload = {'webhookEvent': 'unknown_event'}
|
||||
result = parser.parse(payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'Unhandled webhook event type' in result.skip_reason
|
||||
|
||||
def test_parse_label_event_wrong_label_skipped(self, parser):
|
||||
"""Test label event without OH label is skipped."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': [{'field': 'labels', 'toString': 'other-label'}]},
|
||||
}
|
||||
result = parser.parse(payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'does not contain' in result.skip_reason
|
||||
|
||||
def test_parse_comment_event_no_mention_skipped(self, parser):
|
||||
"""Test comment without mention is skipped."""
|
||||
payload = {
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {
|
||||
'body': 'Regular comment',
|
||||
'author': {'emailAddress': 'test@test.com'},
|
||||
},
|
||||
}
|
||||
result = parser.parse(payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'does not mention' in result.skip_reason
|
||||
|
||||
def test_parse_missing_fields_error(self, parser):
|
||||
"""Test missing required fields returns error."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands'}]},
|
||||
'issue': {'id': '123'}, # Missing key
|
||||
'user': {'emailAddress': 'test@test.com'}, # Missing other fields
|
||||
}
|
||||
result = parser.parse(payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'Missing required fields' in result.error
|
||||
|
||||
|
||||
class TestJiraPayloadParserStagingLabels:
|
||||
"""Tests for JiraPayloadParser with staging labels."""
|
||||
|
||||
@pytest.fixture
|
||||
def staging_parser(self):
|
||||
"""Create a parser with staging labels."""
|
||||
return JiraPayloadParser(
|
||||
oh_label='openhands-exp', inline_oh_label='@openhands-exp'
|
||||
"""Test factory creating existing conversation view"""
|
||||
mock_store.get_user_conversations_by_issue_id = AsyncMock(
|
||||
return_value=jira_conversation
|
||||
)
|
||||
|
||||
def test_parse_staging_label(self, staging_parser):
|
||||
"""Test parsing with staging label."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands-exp'}]},
|
||||
'issue': {
|
||||
'id': '123',
|
||||
'key': 'TEST-1',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/issue/123',
|
||||
},
|
||||
'user': {
|
||||
'emailAddress': 'test@test.com',
|
||||
'displayName': 'Test',
|
||||
'accountId': 'acc123',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/user',
|
||||
},
|
||||
}
|
||||
result = staging_parser.parse(payload)
|
||||
view = await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
)
|
||||
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.LABELED_TICKET
|
||||
assert isinstance(view, JiraExistingConversationView)
|
||||
assert view.conversation_id == 'conv-123'
|
||||
|
||||
def test_parse_prod_label_in_staging_skipped(self, staging_parser):
|
||||
"""Test prod label is skipped in staging environment."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': [{'field': 'labels', 'toString': 'openhands'}]},
|
||||
}
|
||||
result = staging_parser.parse(payload)
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_jira_view_from_payload_new_conversation(
|
||||
self,
|
||||
mock_store,
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
):
|
||||
"""Test factory creating new conversation view"""
|
||||
mock_store.get_user_conversations_by_issue_id = AsyncMock(return_value=None)
|
||||
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
view = await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
)
|
||||
|
||||
assert isinstance(view, JiraNewConversationView)
|
||||
assert view.conversation_id == ''
|
||||
|
||||
async def test_create_jira_view_from_payload_no_user(
|
||||
self, sample_job_context, sample_user_auth, sample_jira_workspace
|
||||
):
|
||||
"""Test factory with no Jira user"""
|
||||
with pytest.raises(StartingConvoException, match='User not authenticated'):
|
||||
await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
None,
|
||||
sample_jira_workspace, # type: ignore
|
||||
)
|
||||
|
||||
async def test_create_jira_view_from_payload_no_auth(
|
||||
self, sample_job_context, sample_jira_user, sample_jira_workspace
|
||||
):
|
||||
"""Test factory with no SaaS auth"""
|
||||
with pytest.raises(StartingConvoException, match='User not authenticated'):
|
||||
await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
None,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace, # type: ignore
|
||||
)
|
||||
|
||||
async def test_create_jira_view_from_payload_no_workspace(
|
||||
self, sample_job_context, sample_user_auth, sample_jira_user
|
||||
):
|
||||
"""Test factory with no workspace"""
|
||||
with pytest.raises(StartingConvoException, match='User not authenticated'):
|
||||
await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class TestJiraViewEdgeCases:
|
||||
"""Tests for edge cases and error scenarios"""
|
||||
|
||||
@patch('integrations.jira.jira_view.create_new_conversation')
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_conversation_creation_with_no_user_secrets(
|
||||
self,
|
||||
mock_store,
|
||||
mock_create_conversation,
|
||||
new_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test conversation creation when user has no secrets"""
|
||||
new_conversation_view.saas_user_auth.get_secrets.return_value = None
|
||||
mock_create_conversation.return_value = mock_agent_loop_info
|
||||
mock_store.create_conversation = AsyncMock()
|
||||
|
||||
result = await new_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert result == 'conv-123'
|
||||
# Verify create_new_conversation was called with custom_secrets=None
|
||||
call_kwargs = mock_create_conversation.call_args[1]
|
||||
assert call_kwargs['custom_secrets'] is None
|
||||
|
||||
@patch('integrations.jira.jira_view.create_new_conversation')
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_conversation_creation_store_failure(
|
||||
self,
|
||||
mock_store,
|
||||
mock_create_conversation,
|
||||
new_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test conversation creation when store creation fails"""
|
||||
mock_create_conversation.return_value = mock_agent_loop_info
|
||||
mock_store.create_conversation = AsyncMock(side_effect=Exception('Store error'))
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Failed to create conversation'
|
||||
):
|
||||
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
|
||||
@patch('integrations.jira.jira_view.conversation_manager')
|
||||
@patch('integrations.jira.jira_view.get_final_agent_observation')
|
||||
async def test_existing_conversation_empty_observations(
|
||||
self,
|
||||
mock_get_observation,
|
||||
mock_conversation_manager,
|
||||
mock_setup_init,
|
||||
mock_store_impl,
|
||||
existing_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_conversation_store,
|
||||
mock_conversation_init_data,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test existing conversation with empty observations"""
|
||||
mock_store_impl.return_value = mock_conversation_store
|
||||
mock_setup_init.return_value = mock_conversation_init_data
|
||||
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
|
||||
return_value=mock_agent_loop_info
|
||||
)
|
||||
mock_get_observation.return_value = [] # Empty observations
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Conversation is still starting'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
def test_new_conversation_view_attributes(self, new_conversation_view):
|
||||
"""Test new conversation view attribute access"""
|
||||
assert new_conversation_view.job_context.issue_key == 'TEST-123'
|
||||
assert new_conversation_view.selected_repo == 'test/repo1'
|
||||
assert new_conversation_view.conversation_id == 'conv-123'
|
||||
|
||||
def test_existing_conversation_view_attributes(self, existing_conversation_view):
|
||||
"""Test existing conversation view attribute access"""
|
||||
assert existing_conversation_view.job_context.issue_key == 'TEST-123'
|
||||
assert existing_conversation_view.selected_repo == 'test/repo1'
|
||||
assert existing_conversation_view.conversation_id == 'conv-123'
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
|
||||
@patch('integrations.jira.jira_view.conversation_manager')
|
||||
@patch('integrations.jira.jira_view.get_final_agent_observation')
|
||||
async def test_existing_conversation_message_send_failure(
|
||||
self,
|
||||
mock_get_observation,
|
||||
mock_conversation_manager,
|
||||
mock_setup_init,
|
||||
mock_store_impl,
|
||||
existing_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_conversation_store,
|
||||
mock_conversation_init_data,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test existing conversation when message sending fails"""
|
||||
mock_store_impl.return_value = mock_conversation_store
|
||||
mock_setup_init.return_value = mock_conversation_init_data
|
||||
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
|
||||
return_value=mock_agent_loop_info
|
||||
)
|
||||
mock_conversation_manager.send_event_to_conversation = AsyncMock(
|
||||
side_effect=Exception('Send error')
|
||||
)
|
||||
|
||||
# Mock agent observation with RUNNING state
|
||||
mock_observation = MagicMock()
|
||||
mock_observation.agent_state = AgentState.RUNNING
|
||||
mock_get_observation.return_value = [mock_observation]
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Failed to create conversation'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Test for ResolverUserContext get_secrets and get_latest_token logic.
|
||||
"""Test for ResolverUserContext get_secrets conversion logic.
|
||||
|
||||
This test focuses on testing the actual ResolverUserContext implementation.
|
||||
"""
|
||||
@@ -12,8 +12,7 @@ from pydantic import SecretStr
|
||||
from enterprise.integrations.resolver_context import ResolverUserContext
|
||||
|
||||
# Import the real classes we want to test
|
||||
from openhands.integrations.provider import CustomSecret, ProviderToken
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.integrations.provider import CustomSecret
|
||||
|
||||
# Import the SDK types we need for testing
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
@@ -132,135 +131,3 @@ def test_custom_to_static_conversion():
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == secret_value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for get_latest_token - ensuring string values are returned
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_provider_tokens(
|
||||
tokens_dict: dict[ProviderType, str],
|
||||
) -> dict[ProviderType, ProviderToken]:
|
||||
"""Helper to create provider tokens dictionary."""
|
||||
return {
|
||||
provider_type: ProviderToken(token=SecretStr(token_value))
|
||||
for provider_type, token_value in tokens_dict.items()
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_string(resolver_context, mock_saas_user_auth):
|
||||
"""Test that get_latest_token returns a string, not a ProviderToken object."""
|
||||
# Arrange
|
||||
token_value = 'ghp_test_github_token_123'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_latest_token(ProviderType.GITHUB)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert isinstance(result, str), (
|
||||
f'Expected str, got {type(result).__name__}. '
|
||||
'get_latest_token must return a string for StaticSecret compatibility.'
|
||||
)
|
||||
assert result == token_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_string_for_multiple_providers(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token returns strings for all provider types."""
|
||||
# Arrange
|
||||
provider_tokens = create_provider_tokens(
|
||||
{
|
||||
ProviderType.GITHUB: 'ghp_github_token',
|
||||
ProviderType.GITLAB: 'glpat_gitlab_token',
|
||||
ProviderType.BITBUCKET: 'bitbucket_token',
|
||||
}
|
||||
)
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act & Assert - verify each provider returns a string
|
||||
for provider_type, expected_token in [
|
||||
(ProviderType.GITHUB, 'ghp_github_token'),
|
||||
(ProviderType.GITLAB, 'glpat_gitlab_token'),
|
||||
(ProviderType.BITBUCKET, 'bitbucket_token'),
|
||||
]:
|
||||
result = await resolver_context.get_latest_token(provider_type)
|
||||
assert isinstance(
|
||||
result, str
|
||||
), f'Expected str for {provider_type.name}, got {type(result).__name__}'
|
||||
assert result == expected_token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_none_for_missing_provider(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token returns None when provider is not in tokens."""
|
||||
# Arrange - only GitHub token available
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: 'ghp_token'})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act - request GitLab token which doesn't exist
|
||||
result = await resolver_context.get_latest_token(ProviderType.GITLAB)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_none_when_no_provider_tokens(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token returns None when no provider tokens exist."""
|
||||
# Arrange
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=None)
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_latest_token(ProviderType.GITHUB)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_none_for_empty_token(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token returns None when provider token has no value."""
|
||||
# Arrange - provider exists but token is None
|
||||
provider_tokens = {ProviderType.GITHUB: ProviderToken(token=None)}
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_latest_token(ProviderType.GITHUB)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_can_be_used_with_static_secret(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token result can be used directly with StaticSecret.
|
||||
|
||||
This is a critical integration test to ensure the return value is compatible
|
||||
with how it's used in _setup_secrets_for_git_providers.
|
||||
"""
|
||||
# Arrange
|
||||
token_value = 'ghp_integration_test_token'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act
|
||||
token = await resolver_context.get_latest_token(ProviderType.GITHUB)
|
||||
|
||||
# Assert - this should NOT raise a ValidationError
|
||||
static_secret = StaticSecret(value=token, description='GITHUB authentication token')
|
||||
assert static_secret.get_value() == token_value
|
||||
|
||||
@@ -6,18 +6,17 @@ import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.api_keys import (
|
||||
delete_byor_key_from_litellm,
|
||||
get_llm_api_key_for_byor,
|
||||
verify_byor_key_in_litellm,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
|
||||
class TestVerifyByorKeyInLitellm:
|
||||
"""Test the verify_byor_key_in_litellm function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
async def test_verify_valid_key_returns_true(self, mock_client_class):
|
||||
"""Test that a valid key (200 response) returns True."""
|
||||
# Arrange
|
||||
@@ -33,7 +32,7 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
@@ -43,8 +42,8 @@ class TestVerifyByorKeyInLitellm:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
async def test_verify_invalid_key_401_returns_false(self, mock_client_class):
|
||||
"""Test that an invalid key (401 response) returns False."""
|
||||
# Arrange
|
||||
@@ -59,14 +58,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
async def test_verify_invalid_key_403_returns_false(self, mock_client_class):
|
||||
"""Test that an invalid key (403 response) returns False."""
|
||||
# Arrange
|
||||
@@ -81,14 +80,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
async def test_verify_server_error_returns_false(self, mock_client_class):
|
||||
"""Test that a server error (500) returns False to ensure key validity."""
|
||||
# Arrange
|
||||
@@ -104,14 +103,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
async def test_verify_timeout_returns_false(self, mock_client_class):
|
||||
"""Test that a timeout returns False to ensure key validity."""
|
||||
# Arrange
|
||||
@@ -124,14 +123,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
async def test_verify_network_error_returns_false(self, mock_client_class):
|
||||
"""Test that a network error returns False to ensure key validity."""
|
||||
# Arrange
|
||||
@@ -144,13 +143,13 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', None)
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', None)
|
||||
async def test_verify_missing_api_url_returns_false(self):
|
||||
"""Test that missing LITE_LLM_API_URL returns False."""
|
||||
# Arrange
|
||||
@@ -158,13 +157,13 @@ class TestVerifyByorKeyInLitellm:
|
||||
user_id = 'user-123'
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
async def test_verify_empty_key_returns_false(self):
|
||||
"""Test that empty key returns False."""
|
||||
# Arrange
|
||||
@@ -172,7 +171,7 @@ class TestVerifyByorKeyInLitellm:
|
||||
user_id = 'user-123'
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
@@ -206,7 +205,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_valid_key_in_database_returns_key(
|
||||
self, mock_get_key, mock_verify_key
|
||||
@@ -230,7 +229,7 @@ class TestGetLlmApiKeyForByor:
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_invalid_key_in_database_regenerates(
|
||||
self,
|
||||
@@ -266,7 +265,7 @@ class TestGetLlmApiKeyForByor:
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_invalid_key_deletion_failure_still_regenerates(
|
||||
self,
|
||||
@@ -329,99 +328,3 @@ class TestGetLlmApiKeyForByor:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestDeleteByorKeyFromLitellm:
|
||||
"""Test the delete_byor_key_from_litellm function with alias cleanup."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
async def test_delete_constructs_alias_from_user(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
"""Test that delete_byor_key_from_litellm constructs key alias from user."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
org_id = 'org-456'
|
||||
byor_key = 'sk-byor-key-to-delete'
|
||||
expected_alias = f'BYOR Key - user {user_id}, org {org_id}'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_delete_key.return_value = None
|
||||
|
||||
# Act
|
||||
result = await delete_byor_key_from_litellm(user_id, byor_key)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
mock_delete_key.assert_called_once_with(byor_key, key_alias=expected_alias)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
async def test_delete_without_user_passes_no_alias(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
"""Test that when user is not found, no alias is passed."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
byor_key = 'sk-byor-key-to-delete'
|
||||
|
||||
mock_get_user.return_value = None
|
||||
mock_delete_key.return_value = None
|
||||
|
||||
# Act
|
||||
result = await delete_byor_key_from_litellm(user_id, byor_key)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_delete_key.assert_called_once_with(byor_key, key_alias=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
async def test_delete_without_org_id_passes_no_alias(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
"""Test that when user has no current_org_id, no alias is passed."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
byor_key = 'sk-byor-key-to-delete'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = None
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_delete_key.return_value = None
|
||||
|
||||
# Act
|
||||
result = await delete_byor_key_from_litellm(user_id, byor_key)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_delete_key.assert_called_once_with(byor_key, key_alias=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
async def test_delete_returns_false_on_exception(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
"""Test that exceptions during deletion return False."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
byor_key = 'sk-byor-key-to-delete'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = 'org-456'
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_delete_key.side_effect = Exception('LiteLLM API error')
|
||||
|
||||
# Act
|
||||
result = await delete_byor_key_from_litellm(user_id, byor_key)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import SecretStr
|
||||
from server.routes.github_proxy import add_github_proxy_routes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_github_proxy(monkeypatch):
|
||||
"""Create a FastAPI app with github proxy routes enabled."""
|
||||
# Enable the github proxy endpoints
|
||||
monkeypatch.setenv('GITHUB_PROXY_ENDPOINTS', '1')
|
||||
|
||||
# Mock the config to have a jwt_secret
|
||||
mock_config = type(
|
||||
'MockConfig', (), {'jwt_secret': SecretStr('test-secret-key-for-testing')}
|
||||
)()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
with patch('server.routes.github_proxy.GITHUB_PROXY_ENDPOINTS', True):
|
||||
with patch('server.routes.github_proxy.config', mock_config):
|
||||
add_github_proxy_routes(app)
|
||||
|
||||
# Return app and mock_config so we can use the same config in tests
|
||||
return app, mock_config
|
||||
|
||||
|
||||
def test_state_compress_encrypt_and_decrypt_decompress_roundtrip(
|
||||
app_with_github_proxy, monkeypatch
|
||||
):
|
||||
"""
|
||||
Verify the code path used by github_proxy_start -> github_proxy_callback:
|
||||
- compress payload, encrypt, base64-encode (what the start code does)
|
||||
- base64-decode, decrypt, decompress (what the callback code does)
|
||||
|
||||
This test exercises the actual endpoints to verify the roundtrip works correctly.
|
||||
"""
|
||||
app, mock_config = app_with_github_proxy
|
||||
client = TestClient(app)
|
||||
|
||||
original_state = 'some-state-value'
|
||||
original_redirect_uri = 'https://example.com/redirect'
|
||||
|
||||
# Call github_proxy_start endpoint - it should redirect to GitHub with encrypted state
|
||||
with patch('server.routes.github_proxy.config', mock_config):
|
||||
response = client.get(
|
||||
'/github-proxy/test-subdomain/login/oauth/authorize',
|
||||
params={
|
||||
'state': original_state,
|
||||
'redirect_uri': original_redirect_uri,
|
||||
'client_id': 'test-client-id',
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 307
|
||||
redirect_url = response.headers['location']
|
||||
|
||||
# Verify it redirects to GitHub
|
||||
assert redirect_url.startswith('https://github.com/login/oauth/authorize')
|
||||
|
||||
# Parse the redirect URL to get the encrypted state
|
||||
parsed = urlparse(redirect_url)
|
||||
query_params = parse_qs(parsed.query)
|
||||
encrypted_state = query_params['state'][0]
|
||||
|
||||
# The redirect_uri should now point to our callback
|
||||
assert 'github-proxy/callback' in query_params['redirect_uri'][0]
|
||||
|
||||
# Now simulate GitHub calling back with this encrypted state
|
||||
with patch('server.routes.github_proxy.config', mock_config):
|
||||
callback_response = client.get(
|
||||
'/github-proxy/callback',
|
||||
params={
|
||||
'state': encrypted_state,
|
||||
'code': 'test-auth-code',
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert callback_response.status_code == 307
|
||||
final_redirect = callback_response.headers['location']
|
||||
|
||||
# Verify the callback redirects to the original redirect_uri
|
||||
assert final_redirect.startswith(original_redirect_uri)
|
||||
|
||||
# Parse the final redirect to verify the state was decrypted correctly
|
||||
final_parsed = urlparse(final_redirect)
|
||||
final_params = parse_qs(final_parsed.query)
|
||||
|
||||
assert final_params['state'][0] == original_state
|
||||
assert final_params['code'][0] == 'test-auth-code'
|
||||
@@ -1,5 +1,3 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -21,7 +19,6 @@ from server.routes.integration.jira import (
|
||||
jira_events,
|
||||
unlink_workspace,
|
||||
validate_workspace_integration,
|
||||
verify_jira_signature,
|
||||
)
|
||||
|
||||
|
||||
@@ -64,35 +61,25 @@ def mock_user_auth():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.redis_client', new_callable=MagicMock)
|
||||
async def test_jira_events_invalid_signature(mock_redis, mock_verify, mock_request):
|
||||
async def test_jira_events_invalid_signature(mock_redis, mock_manager, mock_request):
|
||||
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
|
||||
mock_request.body = AsyncMock(return_value=b'{}')
|
||||
mock_request.json = AsyncMock(return_value={})
|
||||
mock_verify.side_effect = HTTPException(
|
||||
status_code=403, detail="Request signatures didn't match!"
|
||||
)
|
||||
mock_manager.validate_request.return_value = (False, None, None)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await jira_events(
|
||||
mock_request, MagicMock(), x_hub_signature='sha256=invalid'
|
||||
)
|
||||
await jira_events(mock_request, MagicMock())
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "Request signatures didn't match!"
|
||||
assert exc_info.value.detail == 'Invalid webhook signature!'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.redis_client')
|
||||
async def test_jira_events_duplicate_request(mock_redis, mock_verify, mock_request):
|
||||
async def test_jira_events_duplicate_request(mock_redis, mock_manager, mock_request):
|
||||
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
|
||||
mock_request.body = AsyncMock(return_value=b'{}')
|
||||
mock_request.json = AsyncMock(return_value={})
|
||||
mock_verify.return_value = None
|
||||
mock_manager.validate_request.return_value = (True, 'sig123', 'payload')
|
||||
mock_redis.exists.return_value = True
|
||||
response = await jira_events(
|
||||
mock_request, MagicMock(), x_hub_signature='sha256=sig123'
|
||||
)
|
||||
response = await jira_events(mock_request, MagicMock())
|
||||
assert response.status_code == 200
|
||||
body = json.loads(response.body)
|
||||
assert body['success'] is True
|
||||
@@ -361,21 +348,18 @@ class TestJiraLinkCreateValidation:
|
||||
# Test jira_events error scenarios
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.redis_client', new_callable=MagicMock)
|
||||
async def test_jira_events_processing_success(
|
||||
mock_redis, mock_verify, mock_manager, mock_request
|
||||
):
|
||||
async def test_jira_events_processing_success(mock_redis, mock_manager, mock_request):
|
||||
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
|
||||
mock_request.body = AsyncMock(return_value=b'{"test": "payload"}')
|
||||
mock_request.json = AsyncMock(return_value={'test': 'payload'})
|
||||
mock_verify.return_value = None
|
||||
mock_manager.validate_request.return_value = (
|
||||
True,
|
||||
'sig123',
|
||||
{'test': 'payload'},
|
||||
)
|
||||
mock_redis.exists.return_value = False
|
||||
|
||||
background_tasks = MagicMock()
|
||||
response = await jira_events(
|
||||
mock_request, background_tasks, x_hub_signature='sha256=sig123'
|
||||
)
|
||||
response = await jira_events(mock_request, background_tasks)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = json.loads(response.body)
|
||||
@@ -385,241 +369,19 @@ async def test_jira_events_processing_success(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.redis_client', new_callable=MagicMock)
|
||||
async def test_jira_events_general_exception(mock_redis, mock_verify, mock_request):
|
||||
async def test_jira_events_general_exception(mock_redis, mock_manager, mock_request):
|
||||
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
|
||||
mock_request.body = AsyncMock(side_effect=Exception('Unexpected error'))
|
||||
mock_request.json = AsyncMock(return_value={})
|
||||
mock_manager.validate_request.side_effect = Exception('Unexpected error')
|
||||
|
||||
response = await jira_events(
|
||||
mock_request, MagicMock(), x_hub_signature='sha256=sig123'
|
||||
)
|
||||
response = await jira_events(mock_request, MagicMock())
|
||||
|
||||
assert response.status_code == 500
|
||||
body = json.loads(response.body)
|
||||
assert 'Internal server error processing webhook' in body['error']
|
||||
|
||||
|
||||
# Test verify_jira_signature
|
||||
class TestVerifyJiraSignature:
|
||||
"""Test Jira webhook signature verification."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_payload(self):
|
||||
"""Sample webhook payload with comment_created event."""
|
||||
return {
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {
|
||||
'body': 'Test comment @openhands',
|
||||
'author': {
|
||||
'emailAddress': 'user@test.com',
|
||||
'displayName': 'Test User',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/user?accountId=123',
|
||||
},
|
||||
},
|
||||
'issue': {
|
||||
'id': '12345',
|
||||
'key': 'TEST-123',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workspace(self):
|
||||
"""Create a mock workspace."""
|
||||
workspace = MagicMock()
|
||||
workspace.id = 1
|
||||
workspace.name = 'test.atlassian.net'
|
||||
workspace.status = 'active'
|
||||
workspace.webhook_secret = 'encrypted_secret'
|
||||
return workspace
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'signature,expected_detail',
|
||||
[
|
||||
(None, 'x-hub-signature header is missing!'),
|
||||
('', 'x-hub-signature header is missing!'),
|
||||
],
|
||||
ids=['signature_none', 'signature_empty'],
|
||||
)
|
||||
async def test_missing_signature(self, signature, expected_detail, sample_payload):
|
||||
"""Test that missing or empty signature raises HTTPException."""
|
||||
body = json.dumps(sample_payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, signature, sample_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == expected_detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'payload',
|
||||
[
|
||||
{'webhookEvent': 'unknown_event'},
|
||||
{'webhookEvent': 'comment_created', 'comment': {}},
|
||||
{'webhookEvent': 'comment_created', 'comment': {'author': {}}},
|
||||
{'webhookEvent': 'jira:issue_updated', 'user': {}},
|
||||
{},
|
||||
],
|
||||
ids=[
|
||||
'unknown_event',
|
||||
'missing_author',
|
||||
'missing_self_url',
|
||||
'issue_updated_missing_self',
|
||||
'empty_payload',
|
||||
],
|
||||
)
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_workspace_name_not_found(self, mock_manager, payload):
|
||||
"""Test that missing workspace name in payload raises HTTPException."""
|
||||
mock_manager.get_workspace_name_from_payload.return_value = None
|
||||
body = json.dumps(payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, 'valid_signature', payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == 'Workspace name not found in payload'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_workspace_not_found_in_database(self, mock_manager, sample_payload):
|
||||
"""Test that workspace not found in database raises HTTPException."""
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=None
|
||||
)
|
||||
body = json.dumps(sample_payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, 'valid_signature', sample_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == 'Unidentified workspace'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'workspace_status',
|
||||
['inactive', 'disabled', 'pending'],
|
||||
ids=['inactive', 'disabled', 'pending'],
|
||||
)
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_workspace_not_active(
|
||||
self, mock_manager, workspace_status, sample_payload, mock_workspace
|
||||
):
|
||||
"""Test that inactive workspace raises HTTPException."""
|
||||
mock_workspace.status = workspace_status
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
body = json.dumps(sample_payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, 'valid_signature', sample_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == 'Workspace is inactive'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.token_manager')
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_signature_mismatch(
|
||||
self, mock_manager, mock_token_mgr, sample_payload, mock_workspace
|
||||
):
|
||||
"""Test that signature mismatch raises HTTPException."""
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_token_mgr.decrypt_text.return_value = 'webhook_secret'
|
||||
body = json.dumps(sample_payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, 'invalid_signature', sample_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "Request signatures didn't match!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.token_manager')
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_valid_signature(
|
||||
self, mock_manager, mock_token_mgr, sample_payload, mock_workspace
|
||||
):
|
||||
"""Test that valid signature passes verification."""
|
||||
webhook_secret = 'webhook_secret'
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_token_mgr.decrypt_text.return_value = webhook_secret
|
||||
|
||||
body = json.dumps(sample_payload).encode()
|
||||
valid_signature = hmac.new(
|
||||
webhook_secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Should not raise any exception
|
||||
result = await verify_jira_signature(body, valid_signature, sample_payload)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'event_type,payload_key,author_key',
|
||||
[
|
||||
('comment_created', 'comment', 'author'),
|
||||
('jira:issue_updated', 'user', None),
|
||||
],
|
||||
ids=['comment_created', 'issue_updated'],
|
||||
)
|
||||
@patch('server.routes.integration.jira.token_manager')
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_valid_signature_different_events(
|
||||
self,
|
||||
mock_manager,
|
||||
mock_token_mgr,
|
||||
event_type,
|
||||
payload_key,
|
||||
author_key,
|
||||
mock_workspace,
|
||||
):
|
||||
"""Test valid signature verification for different webhook events."""
|
||||
webhook_secret = 'webhook_secret'
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_token_mgr.decrypt_text.return_value = webhook_secret
|
||||
|
||||
if event_type == 'comment_created':
|
||||
payload = {
|
||||
'webhookEvent': event_type,
|
||||
'comment': {
|
||||
'body': 'Test',
|
||||
'author': {
|
||||
'self': 'https://test.atlassian.net/rest/api/2/user?id=1'
|
||||
},
|
||||
},
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
'webhookEvent': event_type,
|
||||
'user': {'self': 'https://test.atlassian.net/rest/api/2/user?id=1'},
|
||||
}
|
||||
|
||||
body = json.dumps(payload).encode()
|
||||
valid_signature = hmac.new(
|
||||
webhook_secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
result = await verify_jira_signature(body, valid_signature, payload)
|
||||
assert result is None
|
||||
|
||||
|
||||
# Test create_jira_workspace error scenarios
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.get_user_auth')
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -82,7 +82,7 @@ class TestGetUserId:
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
user_id = _get_user_id('mock-conversation-id')
|
||||
assert user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
assert user_id == 'mock-user-id'
|
||||
|
||||
def test_get_user_id_conversation_not_found(self, session_maker):
|
||||
"""Test getting user ID when conversation doesn't exist."""
|
||||
@@ -105,12 +105,10 @@ class TestGetSessionApiKey:
|
||||
return_value=[mock_agent_loop_info]
|
||||
)
|
||||
|
||||
api_key = await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
api_key = await _get_session_api_key('user-123', 'conv-456')
|
||||
assert api_key == 'test-api-key'
|
||||
mock_manager.get_agent_loop_info.assert_called_once_with(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', filter_to_sids={'conv-456'}
|
||||
'user-123', filter_to_sids={'conv-456'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -120,9 +118,7 @@ class TestGetSessionApiKey:
|
||||
mock_manager.get_agent_loop_info = AsyncMock(return_value=[])
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
await _get_session_api_key('user-123', 'conv-456')
|
||||
|
||||
|
||||
class TestProcessEvent:
|
||||
@@ -146,15 +142,10 @@ class TestProcessEvent:
|
||||
mock_event = MagicMock()
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
|
||||
mock_file_store.write.assert_called_once_with(
|
||||
'users/5594c7b6-f959-4b81-92e9-b09c206f5081/conversations/conv-456/events/event-1.json',
|
||||
'users/user-123/conversations/conv-456/events/event-1.json',
|
||||
json.dumps(content),
|
||||
)
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
@@ -186,19 +177,14 @@ class TestProcessEvent:
|
||||
)
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
mock_invoke_callbacks.assert_called_once_with('conv-456', mock_event)
|
||||
mock_update_working_seconds.assert_called_once()
|
||||
mock_event_store_class.assert_called_once_with(
|
||||
'conv-456', mock_file_store, '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
'conv-456', mock_file_store, 'user-123'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -226,12 +212,7 @@ class TestProcessEvent:
|
||||
mock_event.agent_state = 'running' # Set RUNNING state to skip the update
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
"""Tests for SaasSQLAppConversationInfoService.
|
||||
|
||||
This module tests the SAAS implementation of SQLAppConversationInfoService,
|
||||
focusing on user isolation, SAAS metadata handling, and multi-tenant functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_with_user(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance with a user_id for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111'),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_info() -> AppConversationInfo:
|
||||
"""Create a sample AppConversationInfo for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_123',
|
||||
selected_repository='https://github.com/test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Test Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[123, 456],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multiple_conversation_infos() -> list[AppConversationInfo]:
|
||||
"""Create multiple AppConversationInfo instances for testing."""
|
||||
base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
return [
|
||||
AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=None,
|
||||
sandbox_id=f'sandbox_{i}',
|
||||
selected_repository=f'https://github.com/test/repo{i}',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title=f'Test Conversation {i}',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[i * 100],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=base_time.replace(hour=12 + i),
|
||||
updated_at=base_time.replace(hour=12 + i, minute=30),
|
||||
)
|
||||
for i in range(1, 6) # Create 5 conversations
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user1_context():
|
||||
"""Create user context for user1."""
|
||||
return SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user2_context():
|
||||
"""Create user context for user2."""
|
||||
return SpecifyUserContext(user_id='b2222222-2222-2222-2222-222222222222')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user1(mock_db_session, user1_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user1."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user1_context
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user2(mock_db_session, user2_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user2."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user2_context
|
||||
)
|
||||
|
||||
|
||||
class TestSaasSQLAppConversationInfoService:
|
||||
"""Test suite for SaasSQLAppConversationInfoService."""
|
||||
|
||||
def test_service_initialization(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
user1_context: SpecifyUserContext,
|
||||
):
|
||||
"""Test that the SAAS service is properly initialized."""
|
||||
assert saas_service_user1.user_context == user1_context
|
||||
assert saas_service_user1.db_session is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_context_isolation(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
saas_service_user2: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that different service instances have different user contexts."""
|
||||
user1_id = await saas_service_user1.user_context.get_user_id()
|
||||
user2_id = await saas_service_user2.user_context.get_user_id()
|
||||
|
||||
assert user1_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert user2_id == 'b2222222-2222-2222-2222-222222222222'
|
||||
assert user1_id != user2_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_includes_user_filtering(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _secure_select method includes user filtering."""
|
||||
# This test verifies that the _secure_select method exists and can be called
|
||||
# The actual SQL generation is tested implicitly through integration
|
||||
query = await saas_service_user1._secure_select()
|
||||
assert query is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_info_with_user_id_functionality(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _to_info_with_user_id properly sets user_id from SAAS metadata."""
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
# Create mock metadata objects
|
||||
stored_metadata = MagicMock(spec=StoredConversationMetadata)
|
||||
stored_metadata.conversation_id = '12345678-1234-5678-1234-567812345678'
|
||||
stored_metadata.parent_conversation_id = None
|
||||
stored_metadata.title = 'Test Conversation'
|
||||
stored_metadata.sandbox_id = 'test-sandbox'
|
||||
stored_metadata.selected_repository = None
|
||||
stored_metadata.selected_branch = None
|
||||
stored_metadata.git_provider = None
|
||||
stored_metadata.trigger = None
|
||||
stored_metadata.pr_number = []
|
||||
stored_metadata.llm_model = None
|
||||
from datetime import datetime, timezone
|
||||
|
||||
stored_metadata.created_at = datetime.now(timezone.utc)
|
||||
stored_metadata.last_updated_at = datetime.now(timezone.utc)
|
||||
stored_metadata.accumulated_cost = 0.0
|
||||
stored_metadata.prompt_tokens = 0
|
||||
stored_metadata.completion_tokens = 0
|
||||
stored_metadata.total_tokens = 0
|
||||
stored_metadata.max_budget_per_task = None
|
||||
stored_metadata.cache_read_tokens = 0
|
||||
stored_metadata.cache_write_tokens = 0
|
||||
stored_metadata.reasoning_tokens = 0
|
||||
stored_metadata.context_window = 0
|
||||
stored_metadata.per_turn_token = 0
|
||||
|
||||
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
|
||||
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
saas_metadata.org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
# Test the _to_info_with_user_id method
|
||||
result = saas_service_user1._to_info_with_user_id(
|
||||
stored_metadata, saas_metadata
|
||||
)
|
||||
|
||||
# Verify that the user_id from SAAS metadata is used
|
||||
assert result.created_by_user_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert result.title == 'Test Conversation'
|
||||
assert result.sandbox_id == 'test-sandbox'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database session execute method to return mock users
|
||||
# This mock intercepts User queries and returns a mock user object
|
||||
# with user_id and org_id the same as the user_id_uuid from the query
|
||||
original_execute = async_session.execute
|
||||
|
||||
async def mock_execute(query):
|
||||
query_str = str(query)
|
||||
|
||||
# Check if this is a User query
|
||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
||||
# Extract the UUID from the query parameters
|
||||
# The query will have bound parameters, we need to get the UUID value
|
||||
if hasattr(query, 'compile'):
|
||||
try:
|
||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
||||
query_with_params = str(compiled)
|
||||
|
||||
# Extract UUID from the query string
|
||||
import re
|
||||
|
||||
# Try both formats: with dashes and without dashes
|
||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
||||
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_with_dashes, query_with_params
|
||||
)
|
||||
if not uuid_match:
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_without_dashes, query_with_params
|
||||
)
|
||||
|
||||
if uuid_match:
|
||||
user_id_str = uuid_match.group(0)
|
||||
# If the UUID doesn't have dashes, add them
|
||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.id = user_id_uuid
|
||||
mock_user.current_org_id = user_id_uuid
|
||||
|
||||
# Create a mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
return mock_result
|
||||
except Exception:
|
||||
# If there's any error in parsing, fall back to original execute
|
||||
pass
|
||||
|
||||
# For all other queries, use the original execute method
|
||||
return await original_execute(query)
|
||||
|
||||
# Apply the mock
|
||||
async_session.execute = mock_execute
|
||||
|
||||
# Create services for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='a1111111-1111-1111-1111-111111111111'
|
||||
),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='b2222222-2222-2222-2222-222222222222'
|
||||
),
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
|
||||
# Save conversations
|
||||
await user1_service.save_app_conversation_info(user1_info)
|
||||
await user2_service.save_app_conversation_info(user2_info)
|
||||
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert (
|
||||
user1_page.items[0].created_by_user_id
|
||||
== 'a1111111-1111-1111-1111-111111111111'
|
||||
)
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert (
|
||||
user2_page.items[0].created_by_user_id
|
||||
== 'b2222222-2222-2222-2222-222222222222'
|
||||
)
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
assert user2_from_user1 is None
|
||||
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
@@ -19,14 +19,6 @@ def mock_session_maker(mock_session):
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = 'test-org-123'
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_store(mock_session_maker):
|
||||
return ApiKeyStore(mock_session_maker)
|
||||
@@ -41,13 +33,11 @@ def test_generate_api_key(api_key_store):
|
||||
assert len(key) == len('sk-oh-') + 32
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
def test_create_api_key(api_key_store, mock_session):
|
||||
"""Test creating an API key."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
name = 'Test Key'
|
||||
mock_get_user.return_value = mock_user
|
||||
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
|
||||
|
||||
# Execute
|
||||
@@ -55,15 +45,10 @@ def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
|
||||
# Verify
|
||||
assert result == 'test-api-key'
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
api_key_store.generate_api_key.assert_called_once()
|
||||
|
||||
# Verify the ApiKey was created with the correct org_id
|
||||
added_api_key = mock_session.add.call_args[0][0]
|
||||
assert added_api_key.org_id == mock_user.current_org_id
|
||||
|
||||
|
||||
def test_validate_api_key_valid(api_key_store, mock_session):
|
||||
"""Test validating a valid API key."""
|
||||
@@ -219,12 +204,10 @@ def test_delete_api_key_by_id(api_key_store, mock_session):
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
def test_list_api_keys(api_key_store, mock_session):
|
||||
"""Test listing API keys for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
now = datetime.now(UTC)
|
||||
mock_key1 = MagicMock()
|
||||
mock_key1.id = 1
|
||||
@@ -240,17 +223,15 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
mock_key2.last_used_at = None
|
||||
mock_key2.expires_at = None
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_key1, mock_key2]
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [
|
||||
mock_key1,
|
||||
mock_key2,
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert len(result) == 2
|
||||
assert result[0]['id'] == 1
|
||||
assert result[0]['name'] == 'Key 1'
|
||||
@@ -263,59 +244,3 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
assert result[1]['created_at'] == now
|
||||
assert result[1]['last_used_at'] is None
|
||||
assert result[1]['expires_at'] is None
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test retrieving MCP API key for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_mcp_key = MagicMock()
|
||||
mock_mcp_key.name = 'MCP_API_KEY'
|
||||
mock_mcp_key.key = 'mcp-test-key'
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result == 'mcp-test-key'
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key when none exists."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result is None
|
||||
|
||||
@@ -130,7 +130,6 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -145,15 +144,6 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = None
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = False
|
||||
|
||||
@@ -175,19 +165,20 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -237,7 +228,6 @@ async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -253,13 +243,6 @@ async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -284,7 +267,6 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -300,13 +282,6 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -334,20 +309,20 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
),
|
||||
patch('server.routes.auth.KEYCLOAK_REALM_NAME', 'test-realm'),
|
||||
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
@@ -560,7 +535,6 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -575,14 +549,6 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user = AsyncMock()
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
@@ -610,7 +576,6 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -637,15 +602,6 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -673,7 +629,6 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -700,15 +655,6 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -734,7 +680,6 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -761,16 +706,6 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -790,7 +725,6 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
"""Test keycloak_callback when duplicate email is detected."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
@@ -807,13 +741,6 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -834,7 +761,6 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
"""Test keycloak_callback when duplicate is detected but deletion fails."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
@@ -851,13 +777,6 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -877,7 +796,6 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
@@ -907,14 +825,6 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -936,7 +846,6 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
@@ -964,14 +873,6 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -997,7 +898,6 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
@@ -1024,14 +924,6 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1151,7 +1043,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1180,14 +1071,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1229,7 +1112,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -1245,13 +1127,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
@@ -1297,7 +1172,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1326,14 +1200,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1384,7 +1250,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1413,14 +1278,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1468,7 +1325,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1497,14 +1353,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1551,7 +1399,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1580,14 +1427,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1631,7 +1470,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1660,14 +1498,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1697,7 +1527,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1726,14 +1555,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1769,7 +1590,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.auth.logger') as mock_logger,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1798,14 +1618,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1853,7 +1665,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.logger') as mock_logger,
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -1869,13 +1680,6 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -6,21 +5,22 @@ import pytest
|
||||
import stripe
|
||||
from fastapi import HTTPException, Request, status
|
||||
from httpx import Response
|
||||
from server.routes import billing
|
||||
from integrations.stripe_service import has_payment_method
|
||||
from server.routes.billing import (
|
||||
CreateBillingSessionResponse,
|
||||
CreateCheckoutSessionRequest,
|
||||
GetCreditsResponse,
|
||||
cancel_callback,
|
||||
cancel_subscription,
|
||||
create_checkout_session,
|
||||
create_customer_setup_session,
|
||||
create_subscription_checkout_session,
|
||||
get_credits,
|
||||
has_payment_method,
|
||||
success_callback,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
|
||||
|
||||
@@ -78,32 +78,28 @@ def mock_subscription_request():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await get_credits('mock_user')
|
||||
mock_response = Response(
|
||||
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_credits('mock_user')
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert (
|
||||
exc_info.value.detail
|
||||
== 'Failed to retrieve credit balance from billing service'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_success():
|
||||
mock_response = Response(
|
||||
status_code=200,
|
||||
json={
|
||||
'user_info': {
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
}
|
||||
},
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
@@ -112,23 +108,24 @@ async def test_get_credits_success():
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch('httpx.AsyncClient', return_value=mock_client),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
):
|
||||
result = await get_credits('mock_user')
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = MagicMock(
|
||||
billing_margin=4
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal('74.50') # 100.00 - 25.50 = 74.50
|
||||
result = await get_credits('mock_user')
|
||||
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal(
|
||||
'74.50'
|
||||
) # 100.00 - 25.50 = 74.50 (no billing margin applied)
|
||||
mock_client.__aenter__.return_value.get.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/info?user_id=mock_user',
|
||||
headers={'x-goog-api-key': None},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -141,9 +138,6 @@ async def test_create_checkout_session_stripe_error(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = uuid.uuid4()
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
pytest.raises(Exception, match='Stripe API Error'),
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
@@ -155,15 +149,11 @@ async def test_create_checkout_session_stripe_error(
|
||||
AsyncMock(side_effect=Exception('Stripe API Error')),
|
||||
),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
await create_checkout_session(
|
||||
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
|
||||
@@ -184,10 +174,6 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org_id = uuid.uuid4()
|
||||
mock_org.id = mock_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
patch(
|
||||
@@ -196,15 +182,11 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
@@ -236,8 +218,8 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={'payment_method_save': 'enabled'},
|
||||
success_url='https://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
|
||||
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
|
||||
success_url='http://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
|
||||
cancel_url='http://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
|
||||
)
|
||||
|
||||
# Verify database session creation
|
||||
@@ -271,6 +253,7 @@ async def test_success_callback_stripe_incomplete():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
@@ -298,51 +281,65 @@ async def test_success_callback_success():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
mock_lite_llm_response = Response(
|
||||
status_code=200,
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_lite_llm_update_response = Response(
|
||||
status_code=200, json={}, request=MagicMock()
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
|
||||
) as mock_update_budget,
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_user_settings = MagicMock(billing_margin=None)
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_user_settings
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_subtotal=2500, customer='mock_customer_id'
|
||||
status='complete',
|
||||
amount_subtotal=2500,
|
||||
) # $25.00 in cents
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.return_value = (
|
||||
mock_lite_llm_response
|
||||
)
|
||||
mock_client_instance.__aenter__.return_value.post.return_value = (
|
||||
mock_lite_llm_update_response
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
response = await success_callback('test_session_id', mock_request)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=success'
|
||||
== 'http://test.com/settings/billing?checkout=success'
|
||||
)
|
||||
|
||||
# Verify LiteLLM API calls
|
||||
mock_update_budget.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
125.0, # 100 + (25.00 from Stripe)
|
||||
mock_client_instance.__aenter__.return_value.get.assert_called_once()
|
||||
mock_client_instance.__aenter__.return_value.post.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/update',
|
||||
headers={'x-goog-api-key': None},
|
||||
json={
|
||||
'user_id': 'mock_user',
|
||||
'max_budget': 125,
|
||||
}, # 100 + (25.00 from Stripe)
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
mock_db_session.merge.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@@ -356,28 +353,27 @@ async def test_success_callback_lite_llm_error():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_subtotal=2500
|
||||
status='complete', amount_total=2500
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.side_effect = Exception(
|
||||
'LiteLLM API Error'
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
@@ -401,8 +397,7 @@ async def test_cancel_callback_session_not_found():
|
||||
response = await cancel_callback('test_session_id', mock_request)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify no database updates occurred
|
||||
@@ -428,8 +423,7 @@ async def test_cancel_callback_success():
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
@@ -441,67 +435,314 @@ async def test_cancel_callback_success():
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_with_payment_method():
|
||||
"""Test has_payment_method returns True when user has a payment method."""
|
||||
|
||||
mock_has_payment_method = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[MagicMock()])),
|
||||
) as mock_list_payment_methods,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is True
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_without_payment_method():
|
||||
"""Test has_payment_method returns False when user has no payment method."""
|
||||
mock_has_payment_method = AsyncMock(return_value=False)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
) as mock_list_payment_methods,
|
||||
):
|
||||
mock_has_payment_method.return_value = False
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is False
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_customer_setup_session_success():
|
||||
"""Test successful creation of customer setup session."""
|
||||
mock_request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/create-customer-setup-session',
|
||||
'server': ('test.com', 80),
|
||||
'headers': [],
|
||||
}
|
||||
)
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
async def test_cancel_subscription_success():
|
||||
"""Test successful subscription cancellation."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
|
||||
mock_session = MagicMock()
|
||||
mock_session.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_create = AsyncMock(return_value=mock_session)
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
# Mock Stripe subscription response
|
||||
mock_stripe_subscription = MagicMock()
|
||||
mock_stripe_subscription.cancel_at_period_end = True
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer_by_user_id',
|
||||
AsyncMock(return_value=mock_customer_info),
|
||||
),
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(return_value=mock_stripe_subscription),
|
||||
) as mock_stripe_modify,
|
||||
):
|
||||
result = await create_customer_setup_session(mock_request, 'mock_user')
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
assert isinstance(result, billing.CreateBillingSessionResponse)
|
||||
# Call the function
|
||||
result = await cancel_subscription('test_user')
|
||||
|
||||
# Verify Stripe API was called
|
||||
mock_stripe_modify.assert_called_once_with(
|
||||
'sub_test123', cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Verify database was updated
|
||||
assert mock_subscription_access.cancelled_at is not None
|
||||
mock_session.merge.assert_called_once_with(mock_subscription_access)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify response
|
||||
assert result.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_no_active_subscription():
|
||||
"""Test cancellation when no active subscription exists."""
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session with no subscription found
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'No active subscription found' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_missing_stripe_id():
|
||||
"""Test cancellation when subscription has no Stripe ID."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock subscription without Stripe ID
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id=None, # Missing Stripe ID
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'missing Stripe subscription ID' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_stripe_error():
|
||||
"""Test cancellation when Stripe API fails."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(side_effect=stripe.StripeError('API Error')),
|
||||
),
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to cancel subscription' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_duplicate_prevention(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription when user already has active subscription raises error."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return existing active subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert (
|
||||
'user already has an active subscription'
|
||||
in str(exc_info.value.detail).lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_allows_after_cancellation(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
# Verify Stripe session creation parameters
|
||||
mock_create.assert_called_once_with(
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='https://test.com/?free_credits=success',
|
||||
cancel_url='https://test.com/',
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_success_no_existing(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test successful subscription creation when no existing subscription."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return no existing subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
@@ -3,7 +3,6 @@ Tests for ConversationCallbackProcessor and ConversationCallback models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from storage.conversation_callback import (
|
||||
@@ -12,9 +11,6 @@ from storage.conversation_callback import (
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
@@ -84,22 +80,15 @@ class TestConversationCallback:
|
||||
"""Create a test conversation metadata record."""
|
||||
with session_maker() as session:
|
||||
metadata = StoredConversationMetadata(
|
||||
conversation_id='test_conversation_123'
|
||||
)
|
||||
metadata_saas = StoredConversationMetadataSaas(
|
||||
conversation_id='test_conversation_123',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
conversation_id='test_conversation_123', user_id='test_user_456'
|
||||
)
|
||||
session.add(metadata)
|
||||
session.add(metadata_saas)
|
||||
session.commit()
|
||||
session.refresh(metadata)
|
||||
yield metadata
|
||||
|
||||
# Cleanup
|
||||
session.delete(metadata)
|
||||
session.delete(metadata_saas)
|
||||
session.commit()
|
||||
|
||||
def test_callback_creation(self, conversation_metadata, session_maker):
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
"""
|
||||
Unit tests for email validation dependency (get_admin_user_id).
|
||||
|
||||
Tests the FastAPI dependency that validates @openhands.dev email domain.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
from server.email_validation import get_admin_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock FastAPI request."""
|
||||
return MagicMock(spec=Request)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_auth():
|
||||
"""Create a mock user auth object."""
|
||||
mock_auth = AsyncMock()
|
||||
mock_auth.get_user_email = AsyncMock()
|
||||
return mock_auth
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_success(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: Valid user ID and @openhands.dev email
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: User ID is returned successfully
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@openhands.dev'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act
|
||||
result = await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
# Assert
|
||||
assert result == user_id
|
||||
mock_user_auth.get_user_email.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_no_user_id(mock_request):
|
||||
"""
|
||||
GIVEN: No user ID provided (None)
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_no_email(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: User ID provided but email is None
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = None
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'email not available' in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_invalid_domain(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: User ID and email with non-@openhands.dev domain
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@external.com'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'openhands.dev' in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_empty_string_user_id(mock_request):
|
||||
"""
|
||||
GIVEN: Empty string user ID
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = ''
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_case_sensitivity(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: Email with uppercase @OPENHANDS.DEV domain
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 403 Forbidden is raised (case-sensitive check)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@OPENHANDS.DEV'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_subdomain_not_allowed(
|
||||
mock_request, mock_user_auth
|
||||
):
|
||||
"""
|
||||
GIVEN: Email with subdomain like @test.openhands.dev
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@test.openhands.dev'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_similar_domain_not_allowed(
|
||||
mock_request, mock_user_auth
|
||||
):
|
||||
"""
|
||||
GIVEN: Email with similar but different domain like @openhands.dev.fake.com
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@openhands.dev.fake.com'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_logs_warning_on_invalid_domain(
|
||||
mock_request, mock_user_auth
|
||||
):
|
||||
"""
|
||||
GIVEN: User with invalid email domain
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: Warning is logged with user_id and email_domain
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
invalid_email = 'test@external.com'
|
||||
mock_user_auth.get_user_email.return_value = invalid_email
|
||||
|
||||
with (
|
||||
patch('server.email_validation.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.email_validation.logger') as mock_logger,
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException):
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
# Verify warning was logged
|
||||
mock_logger.warning.assert_called_once()
|
||||
call_args = mock_logger.warning.call_args
|
||||
assert 'Access denied' in call_args[0][0]
|
||||
assert call_args[1]['extra']['user_id'] == user_id
|
||||
assert call_args[1]['extra']['email_domain'] == 'external.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_with_plus_addressing(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: Email with plus addressing (test+tag@openhands.dev)
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: User ID is returned successfully
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test+tag@openhands.dev'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act
|
||||
result = await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
# Assert
|
||||
assert result == user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_with_dots_in_local_part(
|
||||
mock_request, mock_user_auth
|
||||
):
|
||||
"""
|
||||
GIVEN: Email with dots in local part (first.last@openhands.dev)
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: User ID is returned successfully
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'first.last@openhands.dev'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act
|
||||
result = await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
# Assert
|
||||
assert result == user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_empty_email(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: Empty string email
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = ''
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'email not available' in exc_info.value.detail.lower()
|
||||
@@ -1,100 +1,114 @@
|
||||
"""Unit tests for get_user_v1_enabled_setting and is_v1_enabled_for_github_resolver functions."""
|
||||
"""Unit tests for get_user_v1_enabled_setting function."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import (
|
||||
get_user_v1_enabled_setting,
|
||||
is_v1_enabled_for_github_resolver,
|
||||
)
|
||||
from integrations.utils import get_user_v1_enabled_setting
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org():
|
||||
"""Create a mock org object."""
|
||||
org = MagicMock()
|
||||
org.v1_enabled = True # Default to True, can be overridden in tests
|
||||
return org
|
||||
def mock_user_settings():
|
||||
"""Create a mock user settings object."""
|
||||
settings = MagicMock()
|
||||
settings.v1_enabled = True # Default to True, can be overridden in tests
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(mock_org):
|
||||
def mock_settings_store():
|
||||
"""Create a mock settings store."""
|
||||
store = MagicMock()
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock config object."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker():
|
||||
"""Create a mock session maker."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(
|
||||
mock_settings_store, mock_config, mock_session_maker, mock_user_settings
|
||||
):
|
||||
"""Fixture that patches all the common dependencies."""
|
||||
# Patch at the source module since SaasSettingsStore is imported inside the function
|
||||
with patch(
|
||||
'storage.saas_settings_store.SaasSettingsStore',
|
||||
return_value=mock_settings_store,
|
||||
) as mock_store_class, patch(
|
||||
'integrations.utils.get_config', return_value=mock_config
|
||||
) as mock_get_config, patch(
|
||||
'integrations.utils.session_maker', mock_session_maker
|
||||
), patch(
|
||||
'integrations.utils.call_sync_from_async',
|
||||
return_value=mock_org,
|
||||
) as mock_call_sync, patch('integrations.utils.OrgStore') as mock_org_store:
|
||||
return_value=mock_user_settings,
|
||||
) as mock_call_sync:
|
||||
yield {
|
||||
'store_class': mock_store_class,
|
||||
'get_config': mock_get_config,
|
||||
'session_maker': mock_session_maker,
|
||||
'call_sync': mock_call_sync,
|
||||
'org_store': mock_org_store,
|
||||
'org': mock_org,
|
||||
'settings_store': mock_settings_store,
|
||||
'user_settings': mock_user_settings,
|
||||
}
|
||||
|
||||
|
||||
class TestIsV1EnabledForGithubResolver:
|
||||
"""Test cases for is_v1_enabled_for_github_resolver function.
|
||||
|
||||
This function returns True only if BOTH the environment variable
|
||||
ENABLE_V1_GITHUB_RESOLVER is true AND the user's org has v1_enabled=True.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'env_var_enabled,user_setting_enabled,expected_result',
|
||||
[
|
||||
(False, True, False), # Env var disabled, user enabled -> False
|
||||
(True, False, False), # Env var enabled, user disabled -> False
|
||||
(True, True, True), # Both enabled -> True
|
||||
(False, False, False), # Both disabled -> False
|
||||
],
|
||||
)
|
||||
async def test_v1_enabled_combinations(
|
||||
self, mock_dependencies, env_var_enabled, user_setting_enabled, expected_result
|
||||
):
|
||||
"""Test all combinations of environment variable and user setting values."""
|
||||
mock_dependencies['org'].v1_enabled = user_setting_enabled
|
||||
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_enabled
|
||||
):
|
||||
result = await is_v1_enabled_for_github_resolver('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'env_var_value,env_var_bool,expected_result',
|
||||
[
|
||||
('false', False, False), # Environment variable 'false' -> False
|
||||
('true', True, True), # Environment variable 'true' -> True
|
||||
],
|
||||
)
|
||||
async def test_environment_variable_integration(
|
||||
self, mock_dependencies, env_var_value, env_var_bool, expected_result
|
||||
):
|
||||
"""Test that the function properly reads the ENABLE_V1_GITHUB_RESOLVER environment variable."""
|
||||
mock_dependencies['org'].v1_enabled = True
|
||||
|
||||
with patch.dict(
|
||||
os.environ, {'ENABLE_V1_GITHUB_RESOLVER': env_var_value}
|
||||
), patch('integrations.utils.os.getenv', return_value=env_var_value), patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_bool
|
||||
):
|
||||
result = await is_v1_enabled_for_github_resolver('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
|
||||
class TestGetUserV1EnabledSetting:
|
||||
"""Test cases for get_user_v1_enabled_setting function.
|
||||
"""Test cases for get_user_v1_enabled_setting function."""
|
||||
|
||||
This function only returns the user's org v1_enabled setting.
|
||||
It does NOT check the ENABLE_V1_GITHUB_RESOLVER environment variable.
|
||||
"""
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'user_setting_enabled,expected_result',
|
||||
[
|
||||
(True, True), # User enabled -> True
|
||||
(False, False), # User disabled -> False
|
||||
],
|
||||
)
|
||||
async def test_v1_enabled_user_setting(
|
||||
self, mock_dependencies, user_setting_enabled, expected_result
|
||||
):
|
||||
"""Test that the function returns the user's v1_enabled setting."""
|
||||
mock_dependencies['user_settings'].v1_enabled = user_setting_enabled
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_no_user_id(self):
|
||||
"""Test that the function returns False when no user_id is provided."""
|
||||
result = await get_user_v1_enabled_setting(None)
|
||||
assert result is False
|
||||
|
||||
result = await get_user_v1_enabled_setting('')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_settings_is_none(self, mock_dependencies):
|
||||
"""Test that the function returns False when settings is None."""
|
||||
mock_dependencies['call_sync'].return_value = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_v1_enabled_is_none(self, mock_dependencies):
|
||||
"""Test that the function returns False when v1_enabled is None."""
|
||||
mock_dependencies['user_settings'].v1_enabled = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_calls_correct_methods(self, mock_dependencies):
|
||||
"""Test that the function calls the correct methods with correct parameters."""
|
||||
mock_dependencies['org'].v1_enabled = True
|
||||
mock_dependencies['user_settings'].v1_enabled = True
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
|
||||
@@ -102,38 +116,13 @@ class TestGetUserV1EnabledSetting:
|
||||
assert result is True
|
||||
|
||||
# Verify correct methods were called with correct parameters
|
||||
mock_dependencies['get_config'].assert_called_once()
|
||||
mock_dependencies['store_class'].assert_called_once_with(
|
||||
user_id='test_user_123',
|
||||
session_maker=mock_dependencies['session_maker'],
|
||||
config=mock_dependencies['get_config'].return_value,
|
||||
)
|
||||
mock_dependencies['call_sync'].assert_called_once_with(
|
||||
mock_dependencies['org_store'].get_current_org_from_keycloak_user_id,
|
||||
mock_dependencies['settings_store'].get_user_settings_by_keycloak_id,
|
||||
'test_user_123',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_setting_true(self, mock_dependencies):
|
||||
"""Test that the function returns True when org.v1_enabled is True."""
|
||||
mock_dependencies['org'].v1_enabled = True
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_setting_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when org.v1_enabled is False."""
|
||||
mock_dependencies['org'].v1_enabled = False
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_org_returns_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when no org is found."""
|
||||
# Mock call_sync_from_async to return None (no org found)
|
||||
mock_dependencies['call_sync'].return_value = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_v1_enabled_none_returns_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when org.v1_enabled is None."""
|
||||
mock_dependencies['org'].v1_enabled = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,151 +0,0 @@
|
||||
"""
|
||||
Test that the models are correctly defined.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def test_user_model(session_maker):
|
||||
"""Test that the User model works correctly."""
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test_org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
# Create a test user
|
||||
test_user_id = uuid4()
|
||||
user = User(id=test_user_id, current_org_id=org.id, language='en')
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Create org_member relationship
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=1,
|
||||
llm_api_key='test-api-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
|
||||
# Query the user
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user is not None
|
||||
assert queried_user.language == 'en'
|
||||
|
||||
# Query the org
|
||||
queried_org = session.query(Org).filter(Org.id == org.id).first()
|
||||
assert queried_org is not None
|
||||
assert queried_org.name == 'test_org'
|
||||
|
||||
# Query the org_member relationship
|
||||
queried_org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org.id, OrgMember.user_id == user.id)
|
||||
.first()
|
||||
)
|
||||
assert queried_org_member is not None
|
||||
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
|
||||
|
||||
def test_user_model_git_user_fields(session_maker):
|
||||
"""Test that git_user_name and git_user_email columns exist and work correctly."""
|
||||
with session_maker() as session:
|
||||
# Arrange
|
||||
org = Org(name='test_org_git')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
test_user_id = uuid4()
|
||||
|
||||
# Act
|
||||
user = User(
|
||||
id=test_user_id,
|
||||
current_org_id=org.id,
|
||||
git_user_name='Test Git Author',
|
||||
git_user_email='git@example.com',
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Assert
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user.git_user_name == 'Test Git Author'
|
||||
assert queried_user.git_user_email == 'git@example.com'
|
||||
|
||||
|
||||
def test_user_model_git_user_fields_nullable(session_maker):
|
||||
"""Test that git_user_name and git_user_email can be null."""
|
||||
with session_maker() as session:
|
||||
# Arrange
|
||||
org = Org(name='test_org_nullable')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
test_user_id = uuid4()
|
||||
|
||||
# Act - create user without git fields
|
||||
user = User(
|
||||
id=test_user_id,
|
||||
current_org_id=org.id,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Assert
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user.git_user_name is None
|
||||
assert queried_user.git_user_email is None
|
||||
|
||||
|
||||
def test_user_model_git_user_fields_in_table_columns():
|
||||
"""Test that git_user_name and git_user_email are in User table columns."""
|
||||
# Arrange & Act
|
||||
column_names = [c.name for c in User.__table__.columns]
|
||||
|
||||
# Assert
|
||||
assert 'git_user_name' in column_names
|
||||
assert 'git_user_email' in column_names
|
||||
|
||||
|
||||
def test_user_model_git_user_fields_hasattr(session_maker):
|
||||
"""Test that hasattr returns True for git_user_* fields on User model.
|
||||
|
||||
This verifies the fix for SaasSettingsStore.store() which uses hasattr
|
||||
to determine if a field should be persisted to a model.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Arrange
|
||||
org = Org(name='test_org_hasattr')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid4(), current_org_id=org.id)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Assert - hasattr must return True for store() to work
|
||||
assert hasattr(user, 'git_user_name')
|
||||
assert hasattr(user, 'git_user_email')
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user